Skip to content

Commit

Permalink
fix quality
Browse files Browse the repository at this point in the history
  • Loading branch information
philschmid committed Jun 27, 2023
1 parent c80695c commit 6a368d4
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
8 changes: 3 additions & 5 deletions src/sagemaker_huggingface_inference_toolkit/optimum_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_input_shapes(model_dir):
logger.info(
f"Input shapes found in config file. Using input shapes from config with batch size {input_shapes['batch_size']} and sequence length {input_shapes['sequence_length']}"
)
except:
except Exception:
input_shapes_available = False

# return input shapes if available
Expand All @@ -64,7 +64,7 @@ def get_input_shapes(model_dir):

def get_optimum_neuron_pipeline(task, model_dir):
"""Method to get optimum neuron pipeline for a given task. Method checks if task is supported by optimum neuron and if required environment variables are set, in case model is not converted. If all checks pass, optimum neuron pipeline is returned. If checks fail, an error is raised."""
from optimum.neuron.pipelines import pipeline, NEURONX_SUPPORTED_TASKS
from optimum.neuron.pipelines import NEURONX_SUPPORTED_TASKS, pipeline
from optimum.neuron.utils import NEURON_FILE_NAME

# check task support
Expand All @@ -78,9 +78,7 @@ def get_optimum_neuron_pipeline(task, model_dir):
if NEURON_FILE_NAME in os.listdir(model_dir):
export = False
if export:
logger.info(
f"Model is not converted. Checking if required environment variables are set and converting model."
)
logger.info("Model is not converted. Checking if required environment variables are set and converting model.")

# get static input shapes to run inference
input_shapes = get_input_shapes(model_dir)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_optimum_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
@require_inferentia
def test_not_supported_task():
os.environ["HF_TASK"] = "not-supported-task"
with pytest.raises(Exception) as e:
with pytest.raises(Exception):
get_optimum_neuron_pipeline(task=TASK, model_dir=os.getcwd())


Expand Down

0 comments on commit 6a368d4

Please sign in to comment.