From 5569f482cde3401474829c6317aab24c5af511f5 Mon Sep 17 00:00:00 2001 From: osanseviero Date: Mon, 14 Jun 2021 17:38:59 +0200 Subject: [PATCH 1/6] Allow batch for feature-extraction --- .../api_inference_community/validation.py | 18 +++++++--- api-inference-community/tests/test_nlp.py | 33 +++++++++++++++++-- 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/api-inference-community/api_inference_community/validation.py b/api-inference-community/api_inference_community/validation.py index ed7563e5f7..ac5c1e5eeb 100644 --- a/api-inference-community/api_inference_community/validation.py +++ b/api-inference-community/api_inference_community/validation.py @@ -124,9 +124,7 @@ class TableQuestionAnsweringInputsCheck(BaseModel): query: str @validator("table") - def all_rows_must_have_same_length( - cls, table: Dict[str, List[str]] - ): + def all_rows_must_have_same_length(cls, table: Dict[str, List[str]]): rows = list(table.values()) n = len(rows[0]) if all(len(x) == n for x in rows): @@ -134,7 +132,6 @@ def all_rows_must_have_same_length( raise ValueError("All rows in the table must be the same length") - PARAMS_MAPPING = { "conversational": SharedGenerationParams, "fill-mask": FillMaskParamsCheck, @@ -151,6 +148,8 @@ def all_rows_must_have_same_length( "table-question-answering": TableQuestionAnsweringInputsCheck, } +BATCH_ENABLED__PIPELINES = ["feature-extraction"] + def check_params(params, tag): if tag in PARAMS_MAPPING: @@ -161,9 +160,18 @@ def check_params(params, tag): def check_inputs(inputs, tag): if tag in INPUTS_MAPPING: INPUTS_MAPPING[tag].parse_obj(inputs) + elif tag in BATCH_ENABLED__PIPELINES: + if isinstance(inputs, list): + if len(inputs) == 0: + raise ValueError( + "The inputs is invalid, at least one input is required" + ) + if not all(isinstance(input, str) for input in inputs): + raise ValueError("The inputs is invalid, we expect a list of strings") + elif not isinstance(inputs, str): + raise ValueError("The inputs is invalid, we expect a string") else: # Some tasks just expect {inputs: "str"}. Such as: - # feature-extraction # fill-mask # text2text-generation # text-classification diff --git a/api-inference-community/tests/test_nlp.py b/api-inference-community/tests/test_nlp.py index 95e4be5e45..b0dab09aff 100644 --- a/api-inference-community/tests/test_nlp.py +++ b/api-inference-community/tests/test_nlp.py @@ -418,8 +418,8 @@ class TextGenerationTestCase(make_text_generation_test_case("text-generation")): pass -class TasksWithOnlyInputStringTestCase(TestCase): - def test_feature_extraction_accept_string_no_params(self): +class FeatureExtractionTestCase(TestCase): + def test_valid_string(self): bpayload = json.dumps({"inputs": "whatever"}).encode("utf-8") normalized_inputs, processed_params = normalize_payload_nlp( bpayload, "feature-extraction" @@ -427,6 +427,35 @@ def test_feature_extraction_accept_string_no_params(self): self.assertEqual(processed_params, {}) self.assertEqual(normalized_inputs, "whatever") + def test_valid_list_of_strings(self): + inputs = ["hugging", "face"] + bpayload = json.dumps({"inputs": inputs}).encode("utf-8") + normalized_inputs, processed_params = normalize_payload_nlp( + bpayload, "feature-extraction" + ) + self.assertEqual(processed_params, {}) + self.assertEqual(normalized_inputs, inputs) + + def test_invalid_list_with_number(self): + inputs = ["hugging", 5] + bpayload = json.dumps({"inputs": inputs}).encode("utf-8") + with self.assertRaises(ValueError): + normalize_payload_nlp(bpayload, "feature-extraction") + + def test_invalid_empty_list(self): + inputs = [] + bpayload = json.dumps({"inputs": inputs}).encode("utf-8") + with self.assertRaises(ValueError): + normalize_payload_nlp(bpayload, "feature-extraction") + + def test_invalid_input_no_string(self): + inputs = 123 + bpayload = json.dumps({"inputs": inputs}).encode("utf-8") + with self.assertRaises(ValueError): + normalize_payload_nlp(bpayload, "feature-extraction") + + +class TasksWithOnlyInputStringTestCase(TestCase): def test_fill_mask_accept_string_no_params(self): bpayload = json.dumps({"inputs": "whatever"}).encode("utf-8") normalized_inputs, processed_params = normalize_payload_nlp( From 9a39918cc2ed317fa5629d525e4403ec8cf2e3bb Mon Sep 17 00:00:00 2001 From: osanseviero Date: Tue, 15 Jun 2021 14:51:48 +0200 Subject: [PATCH 2/6] Add extra underscore --- api-inference-community/api_inference_community/validation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api-inference-community/api_inference_community/validation.py b/api-inference-community/api_inference_community/validation.py index ac5c1e5eeb..0b0fbd3c90 100644 --- a/api-inference-community/api_inference_community/validation.py +++ b/api-inference-community/api_inference_community/validation.py @@ -148,7 +148,7 @@ def all_rows_must_have_same_length(cls, table: Dict[str, List[str]]): "table-question-answering": TableQuestionAnsweringInputsCheck, } -BATCH_ENABLED__PIPELINES = ["feature-extraction"] +BATCH_ENABLED_PIPELINES = ["feature-extraction"] def check_params(params, tag): @@ -160,7 +160,7 @@ def check_params(params, tag): def check_inputs(inputs, tag): if tag in INPUTS_MAPPING: INPUTS_MAPPING[tag].parse_obj(inputs) - elif tag in BATCH_ENABLED__PIPELINES: + elif tag in BATCH_ENABLED_PIPELINES: if isinstance(inputs, list): if len(inputs) == 0: raise ValueError( From c7629e0fd2f66cc975b86b17f1bd427db3611b51 Mon Sep 17 00:00:00 2001 From: osanseviero Date: Tue, 15 Jun 2021 17:08:38 +0200 Subject: [PATCH 3/6] Use validator to do the validation :D --- .../api_inference_community/validation.py | 44 ++++++++++--------- api-inference-community/tests/test_nlp.py | 10 +---- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/api-inference-community/api_inference_community/validation.py b/api-inference-community/api_inference_community/validation.py index 0b0fbd3c90..8b8b7213e5 100644 --- a/api-inference-community/api_inference_community/validation.py +++ b/api-inference-community/api_inference_community/validation.py @@ -131,6 +131,21 @@ def all_rows_must_have_same_length(cls, table: Dict[str, List[str]]): return table raise ValueError("All rows in the table must be the same length") +class StringOrStringBatchInputCheck(BaseModel): + __root__: Union[List[str], str] + + @validator("__root__") + def input_must_not_be_empty(cls, __root__: Union[List[str], str]): + if isinstance(__root__, list): + if len(__root__) == 0: + raise ValueError( + "The inputs is invalid, at least one input is required" + ) + return __root__ + +class StringInput(BaseModel): + __root__: str + PARAMS_MAPPING = { "conversational": SharedGenerationParams, @@ -144,8 +159,17 @@ def all_rows_must_have_same_length(cls, table: Dict[str, List[str]]): INPUTS_MAPPING = { "conversational": ConversationalInputsCheck, "question-answering": QuestionInputsCheck, + "feature-extraction": StringOrStringBatchInputCheck, "sentence-similarity": SentenceSimilarityInputsCheck, "table-question-answering": TableQuestionAnsweringInputsCheck, + "fill-mask": StringInput, + "summarization": StringInput, + "text2text-generation": StringInput, + "text-generation": StringInput, + "text-classification": StringInput, + "token-classification": StringInput, + "translation": StringInput, + "zero-shot-classification": StringInput, } BATCH_ENABLED_PIPELINES = ["feature-extraction"] @@ -160,26 +184,6 @@ def check_params(params, tag): def check_inputs(inputs, tag): if tag in INPUTS_MAPPING: INPUTS_MAPPING[tag].parse_obj(inputs) - elif tag in BATCH_ENABLED_PIPELINES: - if isinstance(inputs, list): - if len(inputs) == 0: - raise ValueError( - "The inputs is invalid, at least one input is required" - ) - if not all(isinstance(input, str) for input in inputs): - raise ValueError("The inputs is invalid, we expect a list of strings") - elif not isinstance(inputs, str): - raise ValueError("The inputs is invalid, we expect a string") - else: - # Some tasks just expect {inputs: "str"}. Such as: - # fill-mask - # text2text-generation - # text-classification - # text-generation - # token-classification - # translation - if not isinstance(inputs, str): - raise ValueError("The inputs is invalid, we expect a string") return True diff --git a/api-inference-community/tests/test_nlp.py b/api-inference-community/tests/test_nlp.py index b0dab09aff..460736aef2 100644 --- a/api-inference-community/tests/test_nlp.py +++ b/api-inference-community/tests/test_nlp.py @@ -436,8 +436,8 @@ def test_valid_list_of_strings(self): self.assertEqual(processed_params, {}) self.assertEqual(normalized_inputs, inputs) - def test_invalid_list_with_number(self): - inputs = ["hugging", 5] + def test_invalid_list_with_other_type(self): + inputs = ["hugging", [1,2,3]] bpayload = json.dumps({"inputs": inputs}).encode("utf-8") with self.assertRaises(ValueError): normalize_payload_nlp(bpayload, "feature-extraction") @@ -448,12 +448,6 @@ def test_invalid_empty_list(self): with self.assertRaises(ValueError): normalize_payload_nlp(bpayload, "feature-extraction") - def test_invalid_input_no_string(self): - inputs = 123 - bpayload = json.dumps({"inputs": inputs}).encode("utf-8") - with self.assertRaises(ValueError): - normalize_payload_nlp(bpayload, "feature-extraction") - class TasksWithOnlyInputStringTestCase(TestCase): def test_fill_mask_accept_string_no_params(self): From acb0baf2c79271555d339fa3a9882aca2c98814b Mon Sep 17 00:00:00 2001 From: osanseviero Date: Tue, 15 Jun 2021 17:09:29 +0200 Subject: [PATCH 4/6] format --- api-inference-community/api_inference_community/validation.py | 2 ++ api-inference-community/tests/test_nlp.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/api-inference-community/api_inference_community/validation.py b/api-inference-community/api_inference_community/validation.py index 8b8b7213e5..bf50729833 100644 --- a/api-inference-community/api_inference_community/validation.py +++ b/api-inference-community/api_inference_community/validation.py @@ -131,6 +131,7 @@ def all_rows_must_have_same_length(cls, table: Dict[str, List[str]]): return table raise ValueError("All rows in the table must be the same length") + class StringOrStringBatchInputCheck(BaseModel): __root__: Union[List[str], str] @@ -143,6 +144,7 @@ def input_must_not_be_empty(cls, __root__: Union[List[str], str]): ) return __root__ + class StringInput(BaseModel): __root__: str diff --git a/api-inference-community/tests/test_nlp.py b/api-inference-community/tests/test_nlp.py index 460736aef2..15d4b5a386 100644 --- a/api-inference-community/tests/test_nlp.py +++ b/api-inference-community/tests/test_nlp.py @@ -437,7 +437,7 @@ def test_valid_list_of_strings(self): self.assertEqual(normalized_inputs, inputs) def test_invalid_list_with_other_type(self): - inputs = ["hugging", [1,2,3]] + inputs = ["hugging", [1, 2, 3]] bpayload = json.dumps({"inputs": inputs}).encode("utf-8") with self.assertRaises(ValueError): normalize_payload_nlp(bpayload, "feature-extraction") From cc3f97a49b7665343bdeb2b3c6bebe9dfbc90c58 Mon Sep 17 00:00:00 2001 From: osanseviero Date: Tue, 15 Jun 2021 17:10:39 +0200 Subject: [PATCH 5/6] Improve error message --- api-inference-community/api_inference_community/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api-inference-community/api_inference_community/validation.py b/api-inference-community/api_inference_community/validation.py index bf50729833..d1aaa00847 100644 --- a/api-inference-community/api_inference_community/validation.py +++ b/api-inference-community/api_inference_community/validation.py @@ -140,7 +140,7 @@ def input_must_not_be_empty(cls, __root__: Union[List[str], str]): if isinstance(__root__, list): if len(__root__) == 0: raise ValueError( - "The inputs is invalid, at least one input is required" + "The inputs are invalid, at least one input is required" ) return __root__ From 42ed5be60d29f2958ddca8111a1b17f5781ea4ed Mon Sep 17 00:00:00 2001 From: osanseviero Date: Wed, 16 Jun 2021 16:15:21 +0200 Subject: [PATCH 6/6] Make tag validation stricter --- .../api_inference_community/validation.py | 4 +++- api-inference-community/tests/test_nlp.py | 12 ++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/api-inference-community/api_inference_community/validation.py b/api-inference-community/api_inference_community/validation.py index d1aaa00847..02f546f668 100644 --- a/api-inference-community/api_inference_community/validation.py +++ b/api-inference-community/api_inference_community/validation.py @@ -186,7 +186,9 @@ def check_params(params, tag): def check_inputs(inputs, tag): if tag in INPUTS_MAPPING: INPUTS_MAPPING[tag].parse_obj(inputs) - return True + return True + else: + raise ValueError(f"{tag} is not a valid pipeline.") def normalize_payload( diff --git a/api-inference-community/tests/test_nlp.py b/api-inference-community/tests/test_nlp.py index 15d4b5a386..2efe531670 100644 --- a/api-inference-community/tests/test_nlp.py +++ b/api-inference-community/tests/test_nlp.py @@ -10,15 +10,23 @@ class ValidationTestCase(TestCase): def test_malformed_input(self): bpayload = b"\xc3\x28" with self.assertRaises(UnicodeDecodeError): - normalize_payload_nlp(bpayload, "tag") + normalize_payload_nlp(bpayload, "question-answering") def test_accept_raw_string_for_backward_compatibility(self): query = "funny cats" bpayload = query.encode("utf-8") - normalized_inputs, processed_params = normalize_payload_nlp(bpayload, "tag") + normalized_inputs, processed_params = normalize_payload_nlp( + bpayload, "translation" + ) self.assertEqual(processed_params, {}) self.assertEqual(normalized_inputs, query) + def test_invalid_tag(self): + query = "funny cats" + bpayload = query.encode("utf-8") + with self.assertRaises(ValueError): + normalize_payload_nlp(bpayload, "invalid-tag") + class QuestionAnsweringValidationTestCase(TestCase): def test_valid_input(self):