Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Inference]Support sentence transformers clip #495

Merged
merged 13 commits into from
Feb 28, 2024
1 change: 0 additions & 1 deletion benchmark/text-generation/llama2-13b.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from tempfile import TemporaryDirectory

from transformers import AutoTokenizer
Expand Down
87 changes: 82 additions & 5 deletions docs/source/tutorials/sentence_transformers.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ limitations under the License.
-->
# Sentence Transformers on AWS Inferentia with Optimum Neuron

## Text Models

_There is a notebook version of that tutorial [here](https://github.com/huggingface/optimum-neuron/blob/main/notebooks/sentence-transformers/getting-started.ipynb)._

This guide explains how to compile, load, and use [Sentence Transformers (SBERT)](https://www.sbert.net/) models on AWS Inferentia2 with Optimum Neuron, enabling efficient calculation of embeddings. Sentence Transformers are powerful models for generating sentence embeddings. You can use this Sentence Transformers to compute sentence / text embeddings for more than 100 languages. These embeddings can then be compared e.g. with cosine-similarity to find sentences with a similar meaning. This can be useful for semantic textual similarity, semantic search, or paraphrase mining.

_Note: Currently only text models are supported, we are working on vision support for CLIP._


## Convert Sentence Transformers model to AWS Inferentia2
### Convert Sentence Transformers model to AWS Inferentia2

First, you need to convert your Sentence Transformers model to a format compatible with AWS Inferentia2. You can compile Sentence Transformers models with Optimum Neuron using the `optimum-cli` or `NeuronModelForSentenceTransformers` class. Below you will find an example for both approaches. We have to make sure `sentence-transformers` is installed. Thats only needed for exporting the model.

Expand Down Expand Up @@ -52,7 +52,7 @@ Here we will use the `optimum-cli` to convert the model. Similar to the `NeuronM
optimum-cli export neuron -m BAAI/bge-small-en-v1.5 --library-name sentence_transformers --sequence_length 384 --batch_size 1 --task feature-extraction bge_emb_inf2/
```

## Load compiled Sentence Transformers model and run inference
### Load compiled Sentence Transformers model and run inference

Once we have a compiled Sentence Transformers model, which we either exported ourselves or is available on the Hugging Face Hub, we can load it and run inference. For loading the model we can use the `NeuronModelForSentenceTransformers` class, which is an abstraction layer for the `SentenceTransformer` class. The `NeuronModelForSentenceTransformers` class will automatically pad the input to the specified `sequence_length` and run inference on AWS Inferentia2.

Expand All @@ -79,6 +79,83 @@ print(f"token embeddings: {token_embeddings.shape}") # torch.Size([1, 7, 384])
print(f"sentence_embedding: {sentence_embedding.shape}") # torch.Size([1, 384])
```
## Production Usage
### Production Usage
For deploying these models in a production environment, refer to the [Amazon SageMaker Blog](https://www.philschmid.de/inferentia2-embeddings).
## CLIP
### Compile CLIP for AWS Inferentia2
You can compile CLIP models with Optimum Neuron either by using the `optimum-cli` or `NeuronModelForSentenceTransformers` class. Adopt one approach that you prefer:
* With the Optimum CLI
```bash
optimum-cli export neuron -m sentence-transformers/clip-ViT-B-32 --sequence_length 64 --text_batch_size 3 --image_batch_size 1 --num_channels 3 --height 224 --width 224 --task feature-extraction --library-name sentence_transformers --subfolder 0_CLIPModel clip_emb/
```
* With the `NeuronModelForSentenceTransformers` class
```python
from optimum.neuron import NeuronModelForSentenceTransformers
model_id = "sentence-transformers/clip-ViT-B-32"

# configs for compiling model
input_shapes = {
"num_channels": 3,
"height": 224,
"width": 224,
"text_batch_size": 3,
"image_batch_size": 1,
"sequence_length": 64,
}

emb_model = NeuronModelForSentenceTransformers.from_pretrained(
model_id, subfolder="0_CLIPModel", export=True, library_name="sentence_transformers", dynamic_batch_size=False, **input_shapes
)

# Save locally or upload to the HuggingFace Hub
save_directory = "clip_emb/"
emb_model.save_pretrained(save_directory)
```

### Load compiled Sentence Transformers model and run inference

```python
from PIL import Image
from sentence_transformers import util
from transformers import CLIPProcessor

from optimum.neuron import NeuronModelForSentenceTransformers

save_directory = "clip_emb"
emb_model = NeuronModelForSentenceTransformers.from_pretrained(save_directory)

processor = CLIPProcessor.from_pretrained(save_directory)
inputs = processor(
text=["Two dogs in the snow", 'A cat on a table', 'A picture of London at night'], images=Image.open("two_dogs_in_snow.jpg"), return_tensors="pt", padding=True
)

outputs = emb_model(**inputs)


# Compute cosine similarities
cos_scores = util.cos_sim(outputs.image_embeds, outputs.text_embeds)
print(cos_scores)

# tensor([[0.3072, 0.1016, 0.1095]])
```

<Tip>

**Caveat**

Since compiled models with dynamic batching enabled only accept input tensors with the same batch size, we cannot set `dynamic_batch_size=True` if the input texts and images have different batch sizes. And as `NeuronModelForSentenceTransformers` class pads the inputs to the batch sizes (`text_batch_size` and `image_batch_size`) used during the compilation, you could use relatively larger batch sizes during the compilation for flexibility with the trade-off of compute.

eg. if you want to encode 3 or 4 or 5 texts and 1 image, you could set `text_batch_size = 5 = max(3, 4, 5)` and `image_batch_size = 1` during the compilation.

</Tip>
10 changes: 10 additions & 0 deletions optimum/commands/export/neuronx.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,16 @@ def parse_args_neuronx(parser: "ArgumentParser"):
type=int,
help=f"Batch size {doc_input}",
)
input_group.add_argument(
"--text_batch_size",
type=int,
help=f"Batch size of the text inputs {doc_input} (Only applied for multi-modal models)",
)
input_group.add_argument(
"--image_batch_size",
type=int,
help=f"Batch size of the vision inputs {doc_input} (Only applied for multi-modal models)",
)
input_group.add_argument(
"--sequence_length",
type=int,
Expand Down
12 changes: 8 additions & 4 deletions optimum/exporters/neuron/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
is_neuron_available,
is_neuronx_available,
)
from ...neuron.utils.misc import maybe_save_preprocessors
from ...neuron.utils.version_utils import check_compiler_compatibility_for_stable_diffusion
from ...utils import is_diffusers_available, logging
from ...utils.save_utils import maybe_save_preprocessors
from ..error_utils import AtolError, OutputMatchError, ShapeError
from ..tasks import TasksManager
from .base import NeuronDecoderConfig
Expand Down Expand Up @@ -129,9 +129,11 @@ def get_input_shapes_and_config_class(task: str, args: argparse.Namespace) -> Di

def normalize_sentence_transformers_input_shapes(args: argparse.Namespace) -> Dict[str, int]:
args = vars(args) if isinstance(args, argparse.Namespace) else args
mandatory_axes = {"batch_size", "sequence_length"}
if "clip" in args.get("model", "").lower():
mandatory_axes.update(["num_channels", "width", "height"])
mandatory_axes = {"text_batch_size", "image_batch_size", "sequence_length", "num_channels", "width", "height"}
else:
mandatory_axes = {"batch_size", "sequence_length"}

if not mandatory_axes.issubset(set(args.keys())):
raise AttributeError(
f"Shape of {mandatory_axes} are mandatory for neuron compilation, while {mandatory_axes.difference(args.keys())} are not given."
Expand Down Expand Up @@ -237,6 +239,7 @@ def _get_submodels_and_neuron_configs(
task: str,
output: Path,
library_name: Optional[str] = None,
subfolder: str = "",
dynamic_batch_size: bool = False,
model_name_or_path: Optional[Union[str, Path]] = None,
submodels: Optional[Dict[str, Union[Path, str]]] = None,
Expand Down Expand Up @@ -284,7 +287,7 @@ def _get_submodels_and_neuron_configs(
model_name = model_name.split("/")[-1] if model_name else model.config.model_type
output_model_names = {model_name: "model.neuron"}
models_and_neuron_configs = {model_name: (model, neuron_config)}
maybe_save_preprocessors(model_name_or_path, output)
maybe_save_preprocessors(model_name_or_path, output, src_subfolder=subfolder)
return models_and_neuron_configs, output_model_names


Expand Down Expand Up @@ -425,6 +428,7 @@ def main_export(
task=task,
library_name=library_name,
output=output,
subfolder=subfolder,
dynamic_batch_size=dynamic_batch_size,
model_name_or_path=model_name_or_path,
submodels=submodels,
Expand Down
4 changes: 4 additions & 0 deletions optimum/exporters/neuron/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ def __init__(
compiler_type: Optional[str] = None,
compiler_version: Optional[str] = None,
batch_size: Optional[int] = None,
text_batch_size: Optional[int] = None,
image_batch_size: Optional[int] = None,
dynamic_batch_size: bool = False,
sequence_length: Optional[int] = None,
num_choices: Optional[int] = None,
Expand Down Expand Up @@ -176,6 +178,8 @@ def __init__(
# To avoid using **kwargs.
axes_values = {
"batch_size": batch_size,
"text_batch_size": text_batch_size,
"image_batch_size": image_batch_size,
"sequence_length": sequence_length,
"num_choices": num_choices,
"width": width,
Expand Down
2 changes: 1 addition & 1 deletion optimum/exporters/neuron/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def validate_model_outputs(
input_shapes = {}
for axis in config.mandatory_axes:
input_shapes[axis] = getattr(config, axis)
if config.dynamic_batch_size is True:
if config.dynamic_batch_size is True and "batch_size" in input_shapes:
JingyaHuang marked this conversation as resolved.
Show resolved Hide resolved
input_shapes["batch_size"] *= 2

# Reference outputs
Expand Down
20 changes: 18 additions & 2 deletions optimum/exporters/neuron/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
"""Model specific Neuron configurations."""


import copy
from typing import TYPE_CHECKING, Dict, List

import torch
Expand All @@ -23,6 +23,7 @@
from ...utils import (
DummyInputGenerator,
DummySeq2SeqDecoderTextInputGenerator,
DummyTextInputGenerator,
DummyTimestepInputGenerator,
DummyVisionInputGenerator,
NormalizedConfig,
Expand Down Expand Up @@ -276,7 +277,7 @@ def outputs(self) -> List[str]:
class SentenceTransformersCLIPNeuronConfig(CLIPNeuronConfig):
CUSTOM_MODEL_WRAPPER = SentenceTransformersCLIPNeuronWrapper
ATOL_FOR_VALIDATION = 1e-3
INPUT_ARGS = ("batch_size", "sequence_length", "num_channels", "width", "height")
INPUT_ARGS = ("text_batch_size", "image_batch_size", "sequence_length", "num_channels", "width", "height")

@property
def outputs(self) -> List[str]:
Expand All @@ -285,6 +286,21 @@ def outputs(self) -> List[str]:
def patch_model_for_export(self, model, dummy_inputs):
return self.CUSTOM_MODEL_WRAPPER(model, list(dummy_inputs.keys()))

def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
for name, axis_dim in self._axes.items():
self._axes[name] = kwargs.pop(name, axis_dim)

self._validate_mandatory_axes()

other_axes = copy.deepcopy(self._axes)
text_batch_size = other_axes.pop("text_batch_size")
images_batch_size = other_axes.pop("image_batch_size")

return [
DummyTextInputGenerator(self.task, self._normalized_config, batch_size=text_batch_size, **other_axes),
DummyVisionInputGenerator(self.task, self._normalized_config, batch_size=images_batch_size, **other_axes),
]


@register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
class UNetNeuronConfig(VisionNeuronConfig):
Expand Down
43 changes: 28 additions & 15 deletions optimum/neuron/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,28 +218,41 @@ def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
pixel_values: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
**kwargs,
):
neuron_inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
model_type = self.config.neuron["model_type"]
neuron_inputs = {"input_ids": input_ids}
if pixel_values is not None:
neuron_inputs["pixel_values"] = pixel_values
neuron_inputs["attention_mask"] = (
attention_mask # The input order for clip is: input_ids, pixel_values, attention_mask.
)

with self.neuron_padding_manager(neuron_inputs) as inputs:
outputs = self.model(*inputs)
# token_embeddings -> (batch_size, sequencen_len, hidden_size)
token_embeddings = self.remove_padding(
[outputs[0]], dims=[0, 1], indices=[input_ids.shape[0], input_ids.shape[1]]
)[
0
] # Remove padding on batch_size(0), and sequence_length(1)
# sentence_embedding -> (batch_size, hidden_size)
sentence_embedding = self.remove_padding([outputs[1]], dims=[0], indices=[input_ids.shape[0]])[
0
] # Remove padding on batch_size(0)
if "clip" in model_type:
text_embeds = self.remove_padding([outputs[0]], dims=[0], indices=[input_ids.shape[0]])[
0
] # Remove padding on batch_size(0)
JingyaHuang marked this conversation as resolved.
Show resolved Hide resolved
image_embeds = self.remove_padding([outputs[1]], dims=[0], indices=[pixel_values.shape[0]])[
0
] # Remove padding on batch_size(0)
return ModelOutput(text_embeds=text_embeds, image_embeds=image_embeds)
else:
# token_embeddings -> (batch_size, sequencen_len, hidden_size)
token_embeddings = self.remove_padding(
[outputs[0]], dims=[0, 1], indices=[input_ids.shape[0], input_ids.shape[1]]
)[
0
] # Remove padding on batch_size(0), and sequence_length(1)
# sentence_embedding -> (batch_size, hidden_size)
sentence_embedding = self.remove_padding([outputs[1]], dims=[0], indices=[input_ids.shape[0]])[
0
] # Remove padding on batch_size(0)

return ModelOutput(token_embeddings=token_embeddings, sentence_embedding=sentence_embedding)
return ModelOutput(token_embeddings=token_embeddings, sentence_embedding=sentence_embedding)


MASKED_LM_EXAMPLE = r"""
Expand Down
4 changes: 2 additions & 2 deletions optimum/neuron/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from ..exporters.neuron.model_configs import * # noqa: F403
from ..exporters.tasks import TasksManager
from ..modeling_base import OptimizedModel
from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from .utils import (
NEURON_FILE_NAME,
check_if_weights_replacable,
Expand All @@ -40,6 +39,7 @@
store_compilation_config,
)
from .utils.import_utils import is_neuronx_available
from .utils.misc import maybe_load_preprocessors, maybe_save_preprocessors
from .utils.version_utils import check_compiler_compatibility, get_neuroncc_version, get_neuronxcc_version


Expand Down Expand Up @@ -280,7 +280,7 @@ def _export(

input_shapes = {}
for name in neuron_config_constructor.func.get_mandatory_axes_for_task(task):
static_shape = kwargs_shapes.get(name, None) or config.neuron.get("static_" + name, None)
static_shape = kwargs_shapes.get(name, None)
if static_shape is None:
raise AttributeError(
f"Cannot find the value of `{name}` from arguments nor the `config`. `{name}` is mandatory"
Expand Down
Loading
Loading