Align I/O with Inference API #99

Merged on Dec 12, 2024 (26 commits)

Commits (26):
d20528a  Fix `task` type-hint and remove extra space in `logging` (alvarobartt, Nov 12, 2024)
be146c8  Align `transformers` and `diffusers` inputs with Inference API (alvarobartt, Nov 12, 2024)
0b61436  Remove duplicated `sentencepiece` extra requirement (alvarobartt, Nov 12, 2024)
49254e9  Remove `pipeline.task` check for `sentence-transformers` (alvarobartt, Nov 13, 2024)
b45c40a  Add `warning` and `pop` unsupported parameters (alvarobartt, Nov 15, 2024)
b9dec32  Fix `sentence-transformers` pipeline type-hints (alvarobartt, Nov 15, 2024)
77c2bb2  Update `sentence-ranking` type-hints (alvarobartt, Nov 15, 2024)
9b4fc67  Add missing type-hints and clear code a bit (alvarobartt, Nov 15, 2024)
c1d519a  Fix failing `sentence-transformers` tests due to input parsing (alvarobartt, Nov 15, 2024)
7f0d84d  Fix "table-question-answering" payload check (alvarobartt, Nov 15, 2024)
307b27f  Fix "zero-shot-classification" payload check (alvarobartt, Nov 15, 2024)
d3d2b5e  Check that payload is `dict` in advance (alvarobartt, Nov 15, 2024)
64cbeb1  Fix `HuggingFaceHandler` errors and checks (alvarobartt, Nov 15, 2024)
8cbd4be  Fix `sentence-transformers` pipelines as those don't have parameters (alvarobartt, Nov 15, 2024)
0053e97  Fix `INPUT` to `input_data` fixture (alvarobartt, Nov 15, 2024)
b9dbf58  Fix quality in `tests/unit/test_handler.py` (alvarobartt, Nov 15, 2024)
d764e44  Make `parameters` default to empty dict instead of None (alvarobartt, Nov 18, 2024)
5fbe5af  Add note on `token-classification` / `ner` task (alvarobartt, Dec 2, 2024)
21ab873  Update `version` in `setup.py` (alvarobartt, Dec 2, 2024)
7a225e2  Fix `generate_kwargs` payload handling for text2text-based tasks (alvarobartt, Dec 3, 2024)
cd9ebe7  Fix `generate_kwargs` handling to move to flatten first-level dict (alvarobartt, Dec 3, 2024)
280101d  Update `generate_kwargs` handling as sometimes required (alvarobartt, Dec 3, 2024)
42cd852  Remove `generate` from supported generation kwargs key names (alvarobartt, Dec 4, 2024)
01cd7a8  Update `SentenceRankingPipeline` to handle `query`-`texts` pipelines (alvarobartt, Dec 4, 2024)
4ffcdfd  Update typing and fix `sentence-transformers` tests (alvarobartt, Dec 4, 2024)
9d87331  Upgrade `transformers`, `sentence-transformers` and `peft` dependencies (alvarobartt, Dec 12, 2024)

Files changed:
setup.py (4 changes: 2 additions & 2 deletions)
@@ -5,15 +5,15 @@
 # We don't declare our dependency on transformers here because we build with
 # different packages for different variants

-VERSION = "0.5.2"
+VERSION = "0.5.3"

 # Ubuntu packages
 # libsndfile1-dev: torchaudio requires the development version of the libsndfile package which can be installed via a system package manager. On Ubuntu it can be installed as follows: apt install libsndfile1-dev
 # ffmpeg: ffmpeg is required for audio processing. On Ubuntu it can be installed as follows: apt install ffmpeg
 # libavcodec-extra : libavcodec-extra includes additional codecs for ffmpeg

 install_requires = [
-    "transformers[sklearn,sentencepiece,audio,vision,sentencepiece]==4.46.1",
+    "transformers[sklearn,sentencepiece,audio,vision]==4.46.1",
     "huggingface_hub[hf_transfer]==0.26.2",
     # vision
     "Pillow",

src/huggingface_inference_toolkit/diffusers_utils.py (28 changes: 19 additions & 9 deletions)
@@ -22,9 +22,7 @@ def is_diffusers_available():


 class IEAutoPipelineForText2Image:
-    def __init__(
-        self, model_dir: str, device: Union[str, None] = None, **kwargs
-    ):  # needs "cuda" for GPU
+    def __init__(self, model_dir: str, device: Union[str, None] = None, **kwargs):  # needs "cuda" for GPU
         dtype = torch.float32
         if device == "cuda":
             dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16
@@ -36,9 +34,7 @@ def __init__(
         # try to use DPMSolverMultistepScheduler
         if isinstance(self.pipeline, StableDiffusionPipeline):
             try:
-                self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
-                    self.pipeline.scheduler.config
-                )
+                self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config)
             except Exception:
                 pass

@@ -47,6 +43,13 @@ def __call__(
         prompt,
         **kwargs,
     ):
+        if "prompt" in kwargs:
+            logger.warning(
+                "prompt has been provided twice, both via arg and kwargs, so the `prompt` arg will be used "
+                "instead, and the `prompt` in kwargs will be discarded."
+            )
+            kwargs.pop("prompt")
+
         # diffusers doesn't support seed but rather the generator kwarg
         # see: https://github.com/huggingface/api-inference-community/blob/8e577e2d60957959ba02f474b2913d84a9086b82/docker_images/diffusers/app/pipelines/text_to_image.py#L172-L176
         if "seed" in kwargs:
@@ -58,9 +61,16 @@ def __call__(
         # TODO: add support for more images (Reason is correct output)
         if "num_images_per_prompt" in kwargs:
             kwargs.pop("num_images_per_prompt")
-            logger.warning(
-                "Sending num_images_per_prompt > 1 to pipeline is not supported. Using default value 1."
-            )
+            logger.warning("Sending num_images_per_prompt > 1 to pipeline is not supported. Using default value 1.")
+
+        if "target_size" in kwargs:
+            kwargs["height"] = kwargs["target_size"].pop("height", None)
+            kwargs["width"] = kwargs["target_size"].pop("width", None)
+            kwargs.pop("target_size")
+
+        if "output_type" in kwargs and kwargs["output_type"] != "pil":
+            kwargs.pop("output_type")
+            logger.warning("The `output_type` cannot be modified, and PIL will be used by default instead.")

         # Call pipeline with parameters
         out = self.pipeline(prompt, num_images_per_prompt=1, **kwargs)

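For reference, a minimal sketch of the kind of text-to-image payload the updated `__call__` is meant to accept. The payload shape follows the toolkit's usual `{"inputs": ..., "parameters": ...}` convention; the concrete values, and `num_inference_steps` (a standard diffusers kwarg), are illustrative assumptions rather than part of this diff:

# Illustrative request body: `target_size` is flattened into `height`/`width`,
# a non-"pil" `output_type` is dropped with a warning, `num_images_per_prompt`
# is forced back to 1, and `seed` is translated to diffusers' `generator` kwarg
# (per the comment in the diff above).
payload = {
    "inputs": "a watercolor painting of a lighthouse at dawn",
    "parameters": {
        "seed": 42,
        "target_size": {"height": 512, "width": 512},
        "output_type": "np",  # discarded with a warning; PIL output is always used
        "num_inference_steps": 25,
    },
}
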
src/huggingface_inference_toolkit/handler.py (111 changes: 82 additions & 29 deletions)
@@ -1,8 +1,9 @@
 import os
 from pathlib import Path
-from typing import Optional, Union
+from typing import Any, Dict, Literal, Optional, Union

 from huggingface_inference_toolkit.const import HF_TRUST_REMOTE_CODE
+from huggingface_inference_toolkit.sentence_transformers_utils import SENTENCE_TRANSFORMERS_TASKS
 from huggingface_inference_toolkit.utils import (
     check_and_register_custom_pipeline_from_directory,
     get_pipeline,
@@ -12,34 +13,87 @@
 class HuggingFaceHandler:
     """
     A Default Hugging Face Inference Handler which works with all
-    transformers pipelines, Sentence Transformers and Optimum.
+    Transformers, Diffusers, Sentence Transformers and Optimum pipelines.
     """

-    def __init__(self, model_dir: Union[str, Path], task=None, framework="pt"):
+    def __init__(
+        self, model_dir: Union[str, Path], task: Union[str, None] = None, framework: Literal["pt"] = "pt"
+    ) -> None:
         self.pipeline = get_pipeline(
-            model_dir=model_dir,
-            task=task,
+            model_dir=model_dir,  # type: ignore
+            task=task,  # type: ignore
             framework=framework,
             trust_remote_code=HF_TRUST_REMOTE_CODE,
         )

-    def __call__(self, data):
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
         Handles an inference request with input data and makes a prediction.
         Args:
             :data: (obj): the raw request body data.
         :return: prediction output
         """
         inputs = data.pop("inputs", data)
-        parameters = data.pop("parameters", None)
-
-        # pass inputs with all kwargs in data
-        if parameters is not None:
-            prediction = self.pipeline(inputs, **parameters)
-        else:
-            prediction = self.pipeline(inputs)
-        # postprocess the prediction
-        return prediction
+        parameters = data.pop("parameters", {})
+
+        # sentence transformers pipelines do not have the `task` arg
+        if any(isinstance(self.pipeline, v) for v in SENTENCE_TRANSFORMERS_TASKS.values()):
+            return self.pipeline(**inputs) if isinstance(inputs, dict) else self.pipeline(inputs)  # type: ignore
+
+        if self.pipeline.task == "question-answering":
+            if not isinstance(inputs, dict):
+                raise ValueError(f"inputs must be a dict, but a `{type(inputs)}` was provided instead.")
+            if not all(k in inputs for k in {"question", "context"}):
+                raise ValueError(
+                    f"{self.pipeline.task} expects `inputs` to be a dict containing both `question` and "
+                    "`context` as the keys, both of them being either a `str` or a `List[str]`."
+                )
+
+        if self.pipeline.task == "table-question-answering":
+            if not isinstance(inputs, dict):
+                raise ValueError(f"inputs must be a dict, but a `{type(inputs)}` was provided instead.")
+            if "question" in inputs:
+                inputs["query"] = inputs.pop("question")
+            if not all(k in inputs for k in {"table", "query"}):
+                raise ValueError(
+                    f"{self.pipeline.task} expects `inputs` to be a dict containing the keys `table` and "
+                    "either `question` or `query`."
+                )
+
+        if self.pipeline.task.__contains__("translation") or self.pipeline.task in {
+            "text-generation",
+            "image-to-text",
+            "automatic-speech-recognition",
+            "text-to-audio",
+            "text-to-speech",
+        }:
+            # `generate_kwargs` needs to be a dict, `generation_parameters` is here for forward compatibility
+            if "generation_parameters" in parameters:
+                parameters["generate_kwargs"] = parameters.pop("generation_parameters")
+
+        if self.pipeline.task.__contains__("translation") or self.pipeline.task in {"text-generation"}:
+            # flatten the values of `generate_kwargs` as it's not supported as is, but via top-level parameters
+            generate_kwargs = parameters.pop("generate_kwargs", {})
+            for key, value in generate_kwargs.items():
+                parameters[key] = value
+
+        if self.pipeline.task.__contains__("zero-shot-classification"):
+            if "candidateLabels" in parameters:
+                parameters["candidate_labels"] = parameters.pop("candidateLabels")
+            if not isinstance(inputs, dict):
+                inputs = {"sequences": inputs}
+            if "text" in inputs:
+                inputs["sequences"] = inputs.pop("text")
+            if not all(k in inputs for k in {"sequences"}) or not all(k in parameters for k in {"candidate_labels"}):
+                raise ValueError(
+                    f"{self.pipeline.task} expects `inputs` to be either a string or a dict containing the "
+                    "key `text` or `sequences`, and `parameters` to be a dict containing either `candidate_labels` "
+                    "or `candidateLabels`."
+                )
+
+        return (
+            self.pipeline(**inputs, **parameters) if isinstance(inputs, dict) else self.pipeline(inputs, **parameters)  # type: ignore
+        )


 class VertexAIHandler(HuggingFaceHandler):
@@ -48,21 +102,21 @@ class VertexAIHandler(HuggingFaceHandler):
     Vertex AI specific logic for inference.
     """

-    def __init__(self, model_dir: Union[str, Path], task=None, framework="pt"):
-        super().__init__(model_dir, task, framework)
+    def __init__(
+        self, model_dir: Union[str, Path], task: Union[str, None] = None, framework: Literal["pt"] = "pt"
+    ) -> None:
+        super().__init__(model_dir=model_dir, task=task, framework=framework)

-    def __call__(self, data):
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
         Handles an inference request with input data and makes a prediction.
         Args:
             :data: (obj): the raw request body data.
         :return: prediction output
         """
         if "instances" not in data:
-            raise ValueError(
-                "The request body must contain a key 'instances' with a list of instances."
-            )
-        parameters = data.pop("parameters", None)
+            raise ValueError("The request body must contain a key 'instances' with a list of instances.")
+        parameters = data.pop("parameters", {})

         predictions = []
         # iterate over all instances and make predictions
@@ -74,9 +128,7 @@ def __call__(self, data):
         return {"predictions": predictions}


-def get_inference_handler_either_custom_or_default_handler(
-    model_dir: Path, task: Optional[str] = None
-):
+def get_inference_handler_either_custom_or_default_handler(model_dir: Path, task: Optional[str] = None) -> Any:
     """
     Returns the appropriate inference handler based on the given model directory and task.

@@ -88,9 +140,10 @@ def get_inference_handler_either_custom_or_default_handler(model_dir: Path, task: Optional[str] = None) -> Any:
     InferenceHandler: The appropriate inference handler based on the given model directory and task.
     """
     custom_pipeline = check_and_register_custom_pipeline_from_directory(model_dir)
-    if custom_pipeline:
+    if custom_pipeline is not None:
         return custom_pipeline
-    elif os.environ.get("AIP_MODE", None) == "PREDICTION":
+
+    if os.environ.get("AIP_MODE", None) == "PREDICTION":
         return VertexAIHandler(model_dir=model_dir, task=task)
-    else:
-        return HuggingFaceHandler(model_dir=model_dir, task=task)
+
+    return HuggingFaceHandler(model_dir=model_dir, task=task)

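To make the new per-task checks concrete, here is a small hedged sketch of payloads flowing through `HuggingFaceHandler`; the model directory and the sample texts are illustrative assumptions, not part of this diff:

from huggingface_inference_toolkit.handler import HuggingFaceHandler

# Hypothetical extractive QA model stored under /opt/model
handler = HuggingFaceHandler(model_dir="/opt/model", task="question-answering")

# Valid: `inputs` is a dict containing both `question` and `context`
prediction = handler(
    {
        "inputs": {
            "question": "What is the capital of France?",
            "context": "Paris is the capital and most populous city of France.",
        }
    }
)

# Invalid: `context` is missing, so the ValueError from the diff is raised
# handler({"inputs": {"question": "What is the capital of France?"}})

# Zero-shot classification: camelCase `candidateLabels` is normalized to
# `candidate_labels`, and a bare string input is wrapped as {"sequences": ...}:
# {"inputs": "I love this movie!", "parameters": {"candidateLabels": ["positive", "negative"]}}

For `VertexAIHandler`, the same per-instance body must instead be wrapped in an `instances` list, matching the Vertex AI prediction contract enforced by the check above.
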
src/huggingface_inference_toolkit/sentence_transformers_utils.py (77 changes: 56 additions & 21 deletions)
@@ -1,4 +1,10 @@
 import importlib.util
+from typing import Any, Dict, List, Tuple, Union
+
+try:
+    from typing import Literal
+except ImportError:
+    from typing_extensions import Literal

 _sentence_transformers = importlib.util.find_spec("sentence_transformers") is not None

@@ -12,40 +18,73 @@ def is_sentence_transformers_available():


 class SentenceSimilarityPipeline:
-    def __init__(self, model_dir: str, device: str = None, **kwargs):  # needs "cuda" for GPU
+    def __init__(self, model_dir: str, device: Union[str, None] = None, **kwargs: Any) -> None:
+        # `device` needs to be set to "cuda" for GPU
         self.model = SentenceTransformer(model_dir, device=device, **kwargs)

-    def __call__(self, inputs=None):
-        embeddings1 = self.model.encode(
-            inputs["source_sentence"], convert_to_tensor=True
-        )
-        embeddings2 = self.model.encode(inputs["sentences"], convert_to_tensor=True)
+    def __call__(self, source_sentence: str, sentences: List[str]) -> Dict[str, float]:
+        embeddings1 = self.model.encode(source_sentence, convert_to_tensor=True)
+        embeddings2 = self.model.encode(sentences, convert_to_tensor=True)
         similarities = util.pytorch_cos_sim(embeddings1, embeddings2).tolist()[0]
         return {"similarities": similarities}


 class SentenceEmbeddingPipeline:
-    def __init__(self, model_dir: str, device: str = None, **kwargs):  # needs "cuda" for GPU
+    def __init__(self, model_dir: str, device: Union[str, None] = None, **kwargs: Any) -> None:
+        # `device` needs to be set to "cuda" for GPU
         self.model = SentenceTransformer(model_dir, device=device, **kwargs)

-    def __call__(self, inputs):
-        embeddings = self.model.encode(inputs).tolist()
+    def __call__(self, sentences: Union[str, List[str]]) -> Dict[str, List[float]]:
+        embeddings = self.model.encode(sentences).tolist()
         return {"embeddings": embeddings}


-class RankingPipeline:
-    def __init__(self, model_dir: str, device: str = None, **kwargs):  # needs "cuda" for GPU
+class SentenceRankingPipeline:
+    def __init__(self, model_dir: str, device: Union[str, None] = None, **kwargs: Any) -> None:
+        # `device` needs to be set to "cuda" for GPU
         self.model = CrossEncoder(model_dir, device=device, **kwargs)

-    def __call__(self, inputs):
-        scores = self.model.predict(inputs).tolist()
-        return {"scores": scores}
+    def __call__(
+        self,
+        sentences: Union[Tuple[str, str], List[str], List[List[str]], List[Tuple[str, str]], None] = None,
+        query: Union[str, None] = None,
+        texts: Union[List[str], None] = None,
+        return_documents: bool = False,
+    ) -> Union[Dict[str, List[float]], List[Dict[Literal["index", "score", "text"], Any]]]:
+        if all(x is not None for x in [sentences, query, texts]):
+            raise ValueError(
+                f"The provided payload contains {sentences=} (i.e. 'inputs'), {query=}, and {texts=}"
+                " but all of those cannot be provided, you should provide either only 'sentences' i.e. 'inputs'"
+                " or both 'query' and 'texts' to run the ranking task."
+            )
+
+        if all(x is None for x in [sentences, query, texts]):
+            raise ValueError(
+                "No inputs have been provided within the input payload, make sure that the input payload"
+                " contains either 'sentences' i.e. 'inputs', or both 'query' and 'texts' to run the ranking task."
+            )
+
+        if sentences is not None:
+            scores = self.model.predict(sentences).tolist()
+            return {"scores": scores}
+
+        if query is None or not isinstance(query, str):
+            raise ValueError(f"Provided {query=} but a non-empty string should be provided instead.")
+
+        if texts is None or not isinstance(texts, list) or not all(isinstance(text, str) for text in texts):
+            raise ValueError(f"Provided {texts=}, but a list of non-empty strings should be provided instead.")
+
+        scores = self.model.rank(query, texts, return_documents=return_documents)
+        # rename "corpus_id" key to "index" for all scores to match TEI
+        for score in scores:
+            score["index"] = score.pop("corpus_id")  # type: ignore
+        return scores  # type: ignore


 SENTENCE_TRANSFORMERS_TASKS = {
     "sentence-similarity": SentenceSimilarityPipeline,
     "sentence-embeddings": SentenceEmbeddingPipeline,
-    "sentence-ranking": RankingPipeline,
+    "sentence-ranking": SentenceRankingPipeline,
 }

@@ -56,9 +95,5 @@ def get_sentence_transformers_pipeline(task=None, model_dir=None, device=-1, **kwargs):
     kwargs.pop("framework", None)

     if task not in SENTENCE_TRANSFORMERS_TASKS:
-        raise ValueError(
-            f"Unknown task {task}. Available tasks are: {', '.join(SENTENCE_TRANSFORMERS_TASKS.keys())}"
-        )
-    return SENTENCE_TRANSFORMERS_TASKS[task](
-        model_dir=model_dir, device=device, **kwargs
-    )
+        raise ValueError(f"Unknown task {task}. Available tasks are: {', '.join(SENTENCE_TRANSFORMERS_TASKS.keys())}")
+    return SENTENCE_TRANSFORMERS_TASKS[task](model_dir=model_dir, device=device, **kwargs)

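As a usage sketch of the renamed `SentenceRankingPipeline` and its two input modes (the model directory below is an illustrative assumption; in the toolkit the class is normally instantiated through `get_sentence_transformers_pipeline`):

from huggingface_inference_toolkit.sentence_transformers_utils import SentenceRankingPipeline

# Hypothetical cross-encoder model stored under /opt/model
ranker = SentenceRankingPipeline(model_dir="/opt/model")

# Mode 1: pre-paired `sentences` (the legacy `inputs` field) -> {"scores": [...]}
ranker(
    sentences=[
        ("What is Python?", "Python is a programming language."),
        ("What is Python?", "Pythons are large snakes."),
    ]
)

# Mode 2: TEI-style `query` + `texts` -> [{"index": ..., "score": ...}, ...],
# with "corpus_id" renamed to "index" to match the TEI response format
ranker(
    query="What is Python?",
    texts=["Pythons are large snakes.", "Python is a programming language."],
)

Providing `sentences` together with `query` and `texts`, or none of them, raises the `ValueError`s introduced above.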