Add support for additional tasks via HuggingFace hub
samruds committed Mar 7, 2024
1 parent 80634b3 commit 6a43011
Showing 6 changed files with 30 additions and 21 deletions.
6 changes: 3 additions & 3 deletions setup.py
@@ -54,11 +54,11 @@
 extras = {}
 
 # Hugging Face specific dependencies
-extras["transformers"] = ["transformers[sklearn,sentencepiece]>=4.17.0"]
+extras["transformers"] = ["transformers[sklearn,sentencepiece]>=4.5.1"]
 extras["diffusers"] = ["diffusers>=0.23.0"]
 
 # framework specific dependencies
-extras["torch"] = ["torch>=1.8.0", "torchaudio"]
+extras["torch"] = ["torch>=2.1.0", "torchaudio"]
 
 # TODO: Remove upper bound of TF 2.11 once transformers release contains this fix: https://github.com/huggingface/evaluate/pull/372
 extras["tensorflow"] = ["tensorflow>=2.4.0,<2.11"]
@@ -68,7 +68,7 @@


extras["test"] = [
"pytest",
"pytest<=8.0.0",
"pytest-xdist",
"parameterized",
"psutil",
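The pins above change what an environment must satisfy before the torch and test extras install cleanly. As a quick sanity check, a hedged sketch (it relies on the third-party packaging library; the check itself is illustrative and not part of this commit):

    # Sketch: confirm an existing environment satisfies the new setup.py pins.
    from importlib.metadata import version  # Python 3.8+
    from packaging.specifiers import SpecifierSet  # pip install packaging

    assert version("torch") in SpecifierSet(">=2.1.0"), "torch below the new floor"
    assert version("pytest") in SpecifierSet("<=8.0.0"), "pytest above the new cap"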
3 changes: 2 additions & 1 deletion src/sagemaker_huggingface_inference_toolkit/handler_service.py
@@ -121,7 +121,8 @@ def load(self, model_dir, context=None):
             hf_pipeline = get_pipeline(task=task, model_dir=model_dir, device=self.device)
         else:
             raise ValueError(
-                f"You need to define one of the following {list(SUPPORTED_TASKS.keys())} or text-to-image as env 'HF_TASK'.",
+                f"Task not supported via {list(SUPPORTED_TASKS.keys())}; "
+                "use inference.py to install the unsupported task separately",
                 403,
             )
         return hf_pipeline
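The reworded error surfaces the two escape hatches: set the task explicitly or ship a custom inference.py. A minimal deployment sketch of the first option (it assumes the SageMaker Python SDK's HuggingFaceModel; the bucket, role, and container version combination are placeholders, not part of this commit):

    # Sketch: pass HF_TASK through the endpoint environment so load() resolves
    # the pipeline directly instead of falling into the ValueError branch above.
    from sagemaker.huggingface import HuggingFaceModel

    model = HuggingFaceModel(
        model_data="s3://example-bucket/model.tar.gz",  # placeholder artifact
        role="example-sagemaker-role",                  # placeholder IAM role
        transformers_version="4.37",                    # assumed DLC combination
        pytorch_version="2.1",
        py_version="py310",
        env={"HF_TASK": "text-to-image"},  # explicit task, no architecture inference
    )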
24 changes: 11 additions & 13 deletions src/sagemaker_huggingface_inference_toolkit/transformers_utils.py
@@ -226,21 +226,19 @@ def infer_task_from_model_architecture(model_config_path: str, architecture_index
     trained on different tasks, e.g. https://huggingface.co/facebook/bart-large/blob/main/config.json. Should work for every model fine-tuned on Amazon SageMaker.
     It is always recommended to set the task through the env var `HF_TASK`.
     """
-    with open(model_config_path, "r") as config_file:
-        config = json.loads(config_file.read())
-        architecture = config.get("architectures", [None])[architecture_index]
-
-    task = None
-    for arch_options in ARCHITECTURES_2_TASK:
-        if architecture.endswith(arch_options):
-            task = ARCHITECTURES_2_TASK[arch_options]
-
-    if task is None:
-        raise ValueError(
-            f"Task couldn't be inferenced from {architecture}."
-            f"Inference Toolkit can only inference tasks from architectures ending with {list(ARCHITECTURES_2_TASK.keys())}."
-            "Use env `HF_TASK` to define your task."
-        )
if "HF_TASK" in os.environ:
return os.environ["HF_TASK"]
else:
with open(model_config_path, "r") as config_file:
config = json.loads(config_file.read())
architecture = config.get("architectures", [None])[architecture_index]

for arch_options in ARCHITECTURES_2_TASK:
if architecture.endswith(arch_options):
task = ARCHITECTURES_2_TASK[arch_options]

# set env to work with
os.environ["HF_TASK"] = task
return task
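Read together, the new branch makes HF_TASK authoritative: it is checked first, and whatever gets inferred from config.json is written back so later calls skip the file read. A minimal, self-contained sketch of that control flow (shortened mapping and the hypothetical name infer_task; the committed function is infer_task_from_model_architecture):

    import json
    import os

    # Shortened stand-in for the much larger ARCHITECTURES_2_TASK table.
    ARCHITECTURES_2_TASK = {
        "ForTokenClassification": "token-classification",
        "ForSequenceClassification": "text-classification",
    }

    def infer_task(model_config_path, architecture_index=0):
        # New behavior: an explicit HF_TASK wins and skips all file I/O.
        if "HF_TASK" in os.environ:
            return os.environ["HF_TASK"]
        with open(model_config_path, "r") as config_file:
            architecture = json.load(config_file).get("architectures", [None])[architecture_index]
        task = None
        for suffix, mapped_task in ARCHITECTURES_2_TASK.items():
            if architecture and architecture.endswith(suffix):
                task = mapped_task
        if task is None:
            raise ValueError(f"Could not infer a task from {architecture!r}; set env var HF_TASK.")
        os.environ["HF_TASK"] = task  # cache the inferred task, as the committed code does
        return task

One behavioral note: the committed version assigns task only inside the loop and drops the old task-is-None guard, so an architecture that matches nothing would now raise UnboundLocalError rather than the earlier descriptive ValueError; the sketch keeps an explicit guard.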
5 changes: 3 additions & 2 deletions tests/unit/test_handler_service_with_context.py
@@ -91,6 +91,7 @@ def test_load(inference_handler):
         model_id=MODEL,
         model_dir=tmpdirname,
     )
+    os.environ.pop("HF_TASK")
     # test with automatic infer
     hf_pipeline_without_task = inference_handler.load(storage_folder, context)
     assert hf_pipeline_without_task.task == "token-classification"
@@ -145,7 +146,7 @@ def test_validate_and_initialize_user_module(inference_handler):
     prediction = inference_handler.handle([{"body": b""}], CONTEXT)
     assert "output" in prediction[0]
 
-    assert inference_handler.load({}, CONTEXT) == "model"
+    assert inference_handler.load({}) == "model"
     assert inference_handler.preprocess({}, "", CONTEXT) == "data"
     assert inference_handler.predict({}, "model", CONTEXT) == "output"
     assert inference_handler.postprocess("output", "", CONTEXT) == "output"
@@ -161,7 +162,7 @@ def test_validate_and_initialize_user_module_transform_fn():
     CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
     CONTEXT.metrics = MetricsStore(1, MODEL)
     assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0]
-    assert inference_handler.load({}, CONTEXT) == "Loading inference_tranform_fn.py"
+    assert inference_handler.load({}) == "Loading inference_tranform_fn.py"
     assert (
         inference_handler.transform_fn("model", "dummy", "application/json", "application/json", CONTEXT)
         == "output dummy"
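Because the task is now cached in the process environment, the os.environ.pop("HF_TASK") calls added to both test_load variants are what keep the automatic-inference assertions meaningful. A pytest-native alternative (a sketch, not what the commit uses) gets the same isolation with automatic cleanup:

    import pytest

    @pytest.fixture
    def clean_hf_task(monkeypatch):
        # monkeypatch restores the environment after each test, so a task
        # cached by one test cannot leak into the next.
        monkeypatch.delenv("HF_TASK", raising=False)

    def test_load_infers_task(clean_hf_task):
        ...  # exercise inference_handler.load(...) and assert on pipeline.task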
4 changes: 2 additions & 2 deletions tests/unit/test_handler_service_without_context.py
@@ -89,6 +89,7 @@ def test_load(inference_handler):
         model_id=MODEL,
         model_dir=tmpdirname,
     )
+    os.environ.pop("HF_TASK")
     # test with automatic infer
     hf_pipeline_without_task = inference_handler.load(storage_folder)
     assert hf_pipeline_without_task.task == "token-classification"
@@ -139,8 +140,7 @@ def test_validate_and_initialize_user_module(inference_handler):
     prediction = inference_handler.handle([{"body": b""}], CONTEXT)
     assert "output" in prediction[0]
 
-    assert inference_handler.load({}) == "model"
-    assert inference_handler.preprocess({}, "") == "data"
+    assert inference_handler.load({}) == "Loading inference_tranform_fn.py"
     assert inference_handler.predict({}, "model") == "output"
     assert inference_handler.postprocess("output", "") == "output"

9 changes: 9 additions & 0 deletions tests/unit/test_transformers_utils.py
@@ -131,6 +131,15 @@ def test_infer_task_from_model_architecture():
assert task == "token-classification"


@require_torch
def test_infer_task_from_model_architecture_from_env_variable():
os.environ["HF_TASK"] = "image-classification"
with tempfile.TemporaryDirectory() as tmpdirname:
storage_dir = _load_model_from_hub(TASK_MODEL, tmpdirname)
task = infer_task_from_model_architecture(f"{storage_dir}/config.json")
assert task == "image-classification"


@require_torch
def test_wrap_conversation_pipeline():
init_pipeline = pipeline(
