diff --git a/prepare/cards/mnli.py b/prepare/cards/mnli.py index 97ca61375b..f0dfa53d0d 100644 --- a/prepare/cards/mnli.py +++ b/prepare/cards/mnli.py @@ -2,6 +2,7 @@ AddFields, LoadHF, MapInstanceValues, + RenameFields, TaskCard, ) from src.unitxt.catalog import add_to_catalog @@ -13,17 +14,21 @@ preprocess_steps=[ RenameSplits({"validation_matched": "validation"}), "splitters.small_no_test", + RenameFields(field_to_field={"premise": "text_a", "hypothesis": "text_b"}), MapInstanceValues( mappers={"label": {"0": "entailment", "1": "neutral", "2": "contradiction"}} ), AddFields( fields={ - "choices": ["entailment", "neutral", "contradiction"], + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", + "classes": ["entailment", "neutral", "contradiction"], } ), ], - task="tasks.nli", - templates="templates.classification.nli.all", + task="tasks.classification.multi_class.relation", + templates="templates.classification.multi_class.relation.all", ) test_card(card) diff --git a/prepare/cards/qnli.py b/prepare/cards/qnli.py index d50bec4574..4d089ec60b 100644 --- a/prepare/cards/qnli.py +++ b/prepare/cards/qnli.py @@ -2,18 +2,12 @@ AddFields, LoadHF, MapInstanceValues, - SplitRandomMix, TaskCard, ) from src.unitxt.catalog import add_to_catalog from src.unitxt.operators import RenameFields from src.unitxt.test_utils.card import test_card -default_splitter = SplitRandomMix( - {"train": "train", "validation": "validation", "test": "test"} -) -add_to_catalog(default_splitter, "splitters.default", overwrite=True) - card = TaskCard( loader=LoadHF(path="glue", name="qnli"), preprocess_steps=[ @@ -23,18 +17,21 @@ ), AddFields( fields={ - "choices": ["entailment", "not entailment"], + "classes": ["entailment", "not entailment"], + "type_of_relation": "entailment", + "text_a_type": "question", + "text_b_type": "sentence", } ), RenameFields( field_to_field={ - "question": "premise", - "sentence": "hypothesis", + "question": "text_a", + "sentence": "text_b", } ), ], - task="tasks.nli", - templates="templates.classification.nli.all", + task="tasks.classification.multi_class.relation", + templates="templates.classification.multi_class.relation.all", ) test_card(card) diff --git a/prepare/cards/rte.py b/prepare/cards/rte.py index d0d5e175da..e5fb0b287f 100644 --- a/prepare/cards/rte.py +++ b/prepare/cards/rte.py @@ -1,6 +1,5 @@ from src.unitxt.blocks import ( AddFields, - FormTask, LoadHF, MapInstanceValues, RenameFields, @@ -9,14 +8,6 @@ from src.unitxt.catalog import add_to_catalog from src.unitxt.test_utils.card import test_card -nli_task = FormTask( - inputs=["choices", "premise", "hypothesis"], - outputs=["label"], - metrics=["metrics.accuracy"], -) - -add_to_catalog(nli_task, "tasks.nli", overwrite=True) - card = TaskCard( loader=LoadHF(path="glue", name="rte"), preprocess_steps=[ @@ -26,18 +17,21 @@ ), AddFields( fields={ - "choices": ["entailment", "not entailment"], + "classes": ["entailment", "not entailment"], + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", } ), RenameFields( field_to_field={ - "sentence1": "premise", - "sentence2": "hypothesis", + "sentence1": "text_a", + "sentence2": "text_b", } ), ], - task="tasks.nli", - templates="templates.classification.nli.all", + task="tasks.classification.multi_class.relation", + templates="templates.classification.multi_class.relation.all", ) test_card(card) diff --git a/prepare/cards/wnli.py b/prepare/cards/wnli.py index 9881fda318..7cb6a0f174 100644 --- a/prepare/cards/wnli.py +++ b/prepare/cards/wnli.py @@ -14,8 +14,8 @@ "splitters.small_no_test", RenameFields( field_to_field={ - "sentence1": "premise", - "sentence2": "hypothesis", + "sentence1": "text_a", + "sentence2": "text_b", } ), MapInstanceValues( @@ -23,12 +23,15 @@ ), AddFields( fields={ - "choices": ["entailment", "not entailment"], + "classes": ["entailment", "not entailment"], + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", } ), ], - task="tasks.nli", - templates="templates.classification.nli.all", + task="tasks.classification.multi_class.relation", + templates="templates.classification.multi_class.relation.all", ) test_card(card) diff --git a/prepare/cards/xnli.py b/prepare/cards/xnli.py index 7b73cb6183..a03616ed44 100644 --- a/prepare/cards/xnli.py +++ b/prepare/cards/xnli.py @@ -2,6 +2,7 @@ AddFields, LoadHF, MapInstanceValues, + RenameFields, TaskCard, ) from src.unitxt.catalog import add_to_catalog @@ -33,6 +34,7 @@ preprocess_steps=[ RenameSplits({"validation_matched": "validation"}), "splitters.small_no_test", + RenameFields(field_to_field={"premise": "text_a", "hypothesis": "text_b"}), MapInstanceValues( mappers={ "label": {"0": "entailment", "1": "neutral", "2": "contradiction"} @@ -40,13 +42,16 @@ ), AddFields( fields={ - "choices": ["entailment", "neutral", "contradiction"], + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", + "classes": ["entailment", "neutral", "contradiction"], } ), ], - task="tasks.nli", - templates="templates.classification.nli.all", + task="tasks.classification.multi_class.relation", + templates="templates.classification.multi_class.relation.all", ) - if lang == lang[0]: + if lang == langs[0]: test_card(card) add_to_catalog(card, f"cards.xnli.{lang}", overwrite=True) diff --git a/prepare/tasks/classification.py b/prepare/tasks/classification.py index 9aad403266..3a05333af1 100644 --- a/prepare/tasks/classification.py +++ b/prepare/tasks/classification.py @@ -41,3 +41,21 @@ "tasks.classification.multi_class", overwrite=True, ) + +add_to_catalog( + FormTask( + inputs=[ + "text_a", + "text_a_type", + "text_b", + "text_b_type", + "classes", + "type_of_relation", + ], + outputs=["label"], + metrics=["metrics.f1_micro", "metrics.accuracy", "metrics.f1_macro"], + augmentable_inputs=["text_a", "text_b"], + ), + "tasks.classification.multi_class.relation", + overwrite=True, +) diff --git a/prepare/templates/classification/multi_class/relation.py b/prepare/templates/classification/multi_class/relation.py new file mode 100644 index 0000000000..91cebf68ae --- /dev/null +++ b/prepare/templates/classification/multi_class/relation.py @@ -0,0 +1,36 @@ +from src.unitxt.catalog import add_to_catalog +from src.unitxt.templates import InputOutputTemplate, TemplatesList + +add_to_catalog( + InputOutputTemplate( + input_format="{text_a_type}: {text_a}, {text_b_type}: {text_b}", + output_format="{label}", + target_prefix="The {type_of_relation} class is ", + instruction="Given a {text_a_type} and {text_b_type} classify the {type_of_relation} of the {text_b_type} to one of {classes}.", + postprocessors=[ + "processors.take_first_non_empty_line", + "processors.lower_case_till_punc", + ], + ), + "templates.classification.multi_class.relation.default", + overwrite=True, +) + +add_to_catalog( + InputOutputTemplate( + input_format="Given this {text_a_type}: {text_a}, classify if this {text_b_type}: {text_b} is {classes}.", + output_format="{label}", + postprocessors=[ + "processors.take_first_non_empty_line", + "processors.lower_case_till_punc", + ], + ), + "templates.classification.multi_class.relation.simple", + overwrite=True, +) + +add_to_catalog( + TemplatesList(["templates.classification.multi_class.relation.default"]), + "templates.classification.multi_class.relation.all", + overwrite=True, +) diff --git a/prepare/templates/classification/nli/templates.py b/prepare/templates/classification/nli/templates.py deleted file mode 100644 index 7be31765fe..0000000000 --- a/prepare/templates/classification/nli/templates.py +++ /dev/null @@ -1,25 +0,0 @@ -from src.unitxt.catalog import add_to_catalog -from src.unitxt.templates import InputOutputTemplate, TemplatesList - -add_to_catalog( - InputOutputTemplate( - input_format="Given this sentence: {premise}, classify if this sentence: {hypothesis} is {choices}.", - output_format="{label}", - postprocessors=[ - "processors.take_first_non_empty_line", - "processors.lower_case_till_punc", - ], - ), - "templates.classification.nli.simple", - overwrite=True, -) - -add_to_catalog( - TemplatesList( - [ - "templates.classification.nli.simple", - ] - ), - "templates.classification.nli.all", - overwrite=True, -) diff --git a/src/unitxt/catalog/cards/mnli.json b/src/unitxt/catalog/cards/mnli.json index 136f74af62..77ac97e29a 100644 --- a/src/unitxt/catalog/cards/mnli.json +++ b/src/unitxt/catalog/cards/mnli.json @@ -13,6 +13,13 @@ } }, "splitters.small_no_test", + { + "type": "rename_fields", + "field_to_field": { + "premise": "text_a", + "hypothesis": "text_b" + } + }, { "type": "map_instance_values", "mappers": { @@ -26,7 +33,10 @@ { "type": "add_fields", "fields": { - "choices": [ + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", + "classes": [ "entailment", "neutral", "contradiction" @@ -34,6 +44,6 @@ } } ], - "task": "tasks.nli", - "templates": "templates.classification.nli.all" + "task": "tasks.classification.multi_class.relation", + "templates": "templates.classification.multi_class.relation.all" } diff --git a/src/unitxt/catalog/cards/qnli.json b/src/unitxt/catalog/cards/qnli.json index 6f0b698edd..c04dd48e9b 100644 --- a/src/unitxt/catalog/cards/qnli.json +++ b/src/unitxt/catalog/cards/qnli.json @@ -19,20 +19,23 @@ { "type": "add_fields", "fields": { - "choices": [ + "classes": [ "entailment", "not entailment" - ] + ], + "type_of_relation": "entailment", + "text_a_type": "question", + "text_b_type": "sentence" } }, { "type": "rename_fields", "field_to_field": { - "question": "premise", - "sentence": "hypothesis" + "question": "text_a", + "sentence": "text_b" } } ], - "task": "tasks.nli", - "templates": "templates.classification.nli.all" + "task": "tasks.classification.multi_class.relation", + "templates": "templates.classification.multi_class.relation.all" } diff --git a/src/unitxt/catalog/cards/rte.json b/src/unitxt/catalog/cards/rte.json index d640b92d17..c3f9ffdf1a 100644 --- a/src/unitxt/catalog/cards/rte.json +++ b/src/unitxt/catalog/cards/rte.json @@ -19,20 +19,23 @@ { "type": "add_fields", "fields": { - "choices": [ + "classes": [ "entailment", "not entailment" - ] + ], + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis" } }, { "type": "rename_fields", "field_to_field": { - "sentence1": "premise", - "sentence2": "hypothesis" + "sentence1": "text_a", + "sentence2": "text_b" } } ], - "task": "tasks.nli", - "templates": "templates.classification.nli.all" + "task": "tasks.classification.multi_class.relation", + "templates": "templates.classification.multi_class.relation.all" } diff --git a/src/unitxt/catalog/cards/wnli.json b/src/unitxt/catalog/cards/wnli.json index 90fa35034f..c577fa1f56 100644 --- a/src/unitxt/catalog/cards/wnli.json +++ b/src/unitxt/catalog/cards/wnli.json @@ -10,8 +10,8 @@ { "type": "rename_fields", "field_to_field": { - "sentence1": "premise", - "sentence2": "hypothesis" + "sentence1": "text_a", + "sentence2": "text_b" } }, { @@ -26,13 +26,16 @@ { "type": "add_fields", "fields": { - "choices": [ + "classes": [ "entailment", "not entailment" - ] + ], + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis" } } ], - "task": "tasks.nli", - "templates": "templates.classification.nli.all" + "task": "tasks.classification.multi_class.relation", + "templates": "templates.classification.multi_class.relation.all" } diff --git a/src/unitxt/catalog/cards/xnli/ar.json b/src/unitxt/catalog/cards/xnli/ar.json index 8294e8ec88..2224c0a2c2 100644 --- a/src/unitxt/catalog/cards/xnli/ar.json +++ b/src/unitxt/catalog/cards/xnli/ar.json @@ -13,6 +13,13 @@ } }, "splitters.small_no_test", + { + "type": "rename_fields", + "field_to_field": { + "premise": "text_a", + "hypothesis": "text_b" + } + }, { "type": "map_instance_values", "mappers": { @@ -26,7 +33,10 @@ { "type": "add_fields", "fields": { - "choices": [ + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", + "classes": [ "entailment", "neutral", "contradiction" @@ -34,6 +44,6 @@ } } ], - "task": "tasks.nli", - "templates": "templates.classification.nli.all" + "task": "tasks.classification.multi_class.relation", + "templates": "templates.classification.multi_class.relation.all" } diff --git a/src/unitxt/catalog/cards/xnli/bg.json b/src/unitxt/catalog/cards/xnli/bg.json index 6ced327d48..06746644b6 100644 --- a/src/unitxt/catalog/cards/xnli/bg.json +++ b/src/unitxt/catalog/cards/xnli/bg.json @@ -13,6 +13,13 @@ } }, "splitters.small_no_test", + { + "type": "rename_fields", + "field_to_field": { + "premise": "text_a", + "hypothesis": "text_b" + } + }, { "type": "map_instance_values", "mappers": { @@ -26,7 +33,10 @@ { "type": "add_fields", "fields": { - "choices": [ + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", + "classes": [ "entailment", "neutral", "contradiction" @@ -34,6 +44,6 @@ } } ], - "task": "tasks.nli", - "templates": "templates.classification.nli.all" + "task": "tasks.classification.multi_class.relation", + "templates": "templates.classification.multi_class.relation.all" } diff --git a/src/unitxt/catalog/cards/xnli/de.json b/src/unitxt/catalog/cards/xnli/de.json index 05da471338..16d3bb558d 100644 --- a/src/unitxt/catalog/cards/xnli/de.json +++ b/src/unitxt/catalog/cards/xnli/de.json @@ -13,6 +13,13 @@ } }, "splitters.small_no_test", + { + "type": "rename_fields", + "field_to_field": { + "premise": "text_a", + "hypothesis": "text_b" + } + }, { "type": "map_instance_values", "mappers": { @@ -26,7 +33,10 @@ { "type": "add_fields", "fields": { - "choices": [ + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", + "classes": [ "entailment", "neutral", "contradiction" @@ -34,6 +44,6 @@ } } ], - "task": "tasks.nli", - "templates": "templates.classification.nli.all" + "task": "tasks.classification.multi_class.relation", + "templates": "templates.classification.multi_class.relation.all" } diff --git a/src/unitxt/catalog/cards/xnli/el.json b/src/unitxt/catalog/cards/xnli/el.json index 98aebedd6d..f535f6bdb7 100644 --- a/src/unitxt/catalog/cards/xnli/el.json +++ b/src/unitxt/catalog/cards/xnli/el.json @@ -13,6 +13,13 @@ } }, "splitters.small_no_test", + { + "type": "rename_fields", + "field_to_field": { + "premise": "text_a", + "hypothesis": "text_b" + } + }, { "type": "map_instance_values", "mappers": { @@ -26,7 +33,10 @@ { "type": "add_fields", "fields": { - "choices": [ + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", + "classes": [ "entailment", "neutral", "contradiction" @@ -34,6 +44,6 @@ } } ], - "task": "tasks.nli", - "templates": "templates.classification.nli.all" + "task": "tasks.classification.multi_class.relation", + "templates": "templates.classification.multi_class.relation.all" } diff --git a/src/unitxt/catalog/cards/xnli/en.json b/src/unitxt/catalog/cards/xnli/en.json index 26b73493b8..0f93fe581f 100644 --- a/src/unitxt/catalog/cards/xnli/en.json +++ b/src/unitxt/catalog/cards/xnli/en.json @@ -13,6 +13,13 @@ } }, "splitters.small_no_test", + { + "type": "rename_fields", + "field_to_field": { + "premise": "text_a", + "hypothesis": "text_b" + } + }, { "type": "map_instance_values", "mappers": { @@ -26,7 +33,10 @@ { "type": "add_fields", "fields": { - "choices": [ + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", + "classes": [ "entailment", "neutral", "contradiction" @@ -34,6 +44,6 @@ } } ], - "task": "tasks.nli", - "templates": "templates.classification.nli.all" + "task": "tasks.classification.multi_class.relation", + "templates": "templates.classification.multi_class.relation.all" } diff --git a/src/unitxt/catalog/cards/xnli/es.json b/src/unitxt/catalog/cards/xnli/es.json index 69b0fb243a..3164ce65a5 100644 --- a/src/unitxt/catalog/cards/xnli/es.json +++ b/src/unitxt/catalog/cards/xnli/es.json @@ -13,6 +13,13 @@ } }, "splitters.small_no_test", + { + "type": "rename_fields", + "field_to_field": { + "premise": "text_a", + "hypothesis": "text_b" + } + }, { "type": "map_instance_values", "mappers": { @@ -26,7 +33,10 @@ { "type": "add_fields", "fields": { - "choices": [ + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", + "classes": [ "entailment", "neutral", "contradiction" @@ -34,6 +44,6 @@ } } ], - "task": "tasks.nli", - "templates": "templates.classification.nli.all" + "task": "tasks.classification.multi_class.relation", + "templates": "templates.classification.multi_class.relation.all" } diff --git a/src/unitxt/catalog/cards/xnli/fr.json b/src/unitxt/catalog/cards/xnli/fr.json index c7bf0a6e80..83641462ec 100644 --- a/src/unitxt/catalog/cards/xnli/fr.json +++ b/src/unitxt/catalog/cards/xnli/fr.json @@ -13,6 +13,13 @@ } }, "splitters.small_no_test", + { + "type": "rename_fields", + "field_to_field": { + "premise": "text_a", + "hypothesis": "text_b" + } + }, { "type": "map_instance_values", "mappers": { @@ -26,7 +33,10 @@ { "type": "add_fields", "fields": { - "choices": [ + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", + "classes": [ "entailment", "neutral", "contradiction" @@ -34,6 +44,6 @@ } } ], - "task": "tasks.nli", - "templates": "templates.classification.nli.all" + "task": "tasks.classification.multi_class.relation", + "templates": "templates.classification.multi_class.relation.all" } diff --git a/src/unitxt/catalog/cards/xnli/hi.json b/src/unitxt/catalog/cards/xnli/hi.json index 3a0468712d..cdfa0a62df 100644 --- a/src/unitxt/catalog/cards/xnli/hi.json +++ b/src/unitxt/catalog/cards/xnli/hi.json @@ -13,6 +13,13 @@ } }, "splitters.small_no_test", + { + "type": "rename_fields", + "field_to_field": { + "premise": "text_a", + "hypothesis": "text_b" + } + }, { "type": "map_instance_values", "mappers": { @@ -26,7 +33,10 @@ { "type": "add_fields", "fields": { - "choices": [ + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", + "classes": [ "entailment", "neutral", "contradiction" @@ -34,6 +44,6 @@ } } ], - "task": "tasks.nli", - "templates": "templates.classification.nli.all" + "task": "tasks.classification.multi_class.relation", + "templates": "templates.classification.multi_class.relation.all" } diff --git a/src/unitxt/catalog/cards/xnli/ru.json b/src/unitxt/catalog/cards/xnli/ru.json index fc27faecce..a471f36a0a 100644 --- a/src/unitxt/catalog/cards/xnli/ru.json +++ b/src/unitxt/catalog/cards/xnli/ru.json @@ -13,6 +13,13 @@ } }, "splitters.small_no_test", + { + "type": "rename_fields", + "field_to_field": { + "premise": "text_a", + "hypothesis": "text_b" + } + }, { "type": "map_instance_values", "mappers": { @@ -26,7 +33,10 @@ { "type": "add_fields", "fields": { - "choices": [ + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", + "classes": [ "entailment", "neutral", "contradiction" @@ -34,6 +44,6 @@ } } ], - "task": "tasks.nli", - "templates": "templates.classification.nli.all" + "task": "tasks.classification.multi_class.relation", + "templates": "templates.classification.multi_class.relation.all" } diff --git a/src/unitxt/catalog/cards/xnli/sw.json b/src/unitxt/catalog/cards/xnli/sw.json index 4f1a743dba..8d443b6a9d 100644 --- a/src/unitxt/catalog/cards/xnli/sw.json +++ b/src/unitxt/catalog/cards/xnli/sw.json @@ -13,6 +13,13 @@ } }, "splitters.small_no_test", + { + "type": "rename_fields", + "field_to_field": { + "premise": "text_a", + "hypothesis": "text_b" + } + }, { "type": "map_instance_values", "mappers": { @@ -26,7 +33,10 @@ { "type": "add_fields", "fields": { - "choices": [ + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", + "classes": [ "entailment", "neutral", "contradiction" @@ -34,6 +44,6 @@ } } ], - "task": "tasks.nli", - "templates": "templates.classification.nli.all" + "task": "tasks.classification.multi_class.relation", + "templates": "templates.classification.multi_class.relation.all" } diff --git a/src/unitxt/catalog/cards/xnli/th.json b/src/unitxt/catalog/cards/xnli/th.json index 9bb9f436a8..a52fcfb71e 100644 --- a/src/unitxt/catalog/cards/xnli/th.json +++ b/src/unitxt/catalog/cards/xnli/th.json @@ -13,6 +13,13 @@ } }, "splitters.small_no_test", + { + "type": "rename_fields", + "field_to_field": { + "premise": "text_a", + "hypothesis": "text_b" + } + }, { "type": "map_instance_values", "mappers": { @@ -26,7 +33,10 @@ { "type": "add_fields", "fields": { - "choices": [ + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", + "classes": [ "entailment", "neutral", "contradiction" @@ -34,6 +44,6 @@ } } ], - "task": "tasks.nli", - "templates": "templates.classification.nli.all" + "task": "tasks.classification.multi_class.relation", + "templates": "templates.classification.multi_class.relation.all" } diff --git a/src/unitxt/catalog/cards/xnli/tr.json b/src/unitxt/catalog/cards/xnli/tr.json index 8e703ba496..95053a7a7e 100644 --- a/src/unitxt/catalog/cards/xnli/tr.json +++ b/src/unitxt/catalog/cards/xnli/tr.json @@ -13,6 +13,13 @@ } }, "splitters.small_no_test", + { + "type": "rename_fields", + "field_to_field": { + "premise": "text_a", + "hypothesis": "text_b" + } + }, { "type": "map_instance_values", "mappers": { @@ -26,7 +33,10 @@ { "type": "add_fields", "fields": { - "choices": [ + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", + "classes": [ "entailment", "neutral", "contradiction" @@ -34,6 +44,6 @@ } } ], - "task": "tasks.nli", - "templates": "templates.classification.nli.all" + "task": "tasks.classification.multi_class.relation", + "templates": "templates.classification.multi_class.relation.all" } diff --git a/src/unitxt/catalog/cards/xnli/ur.json b/src/unitxt/catalog/cards/xnli/ur.json index 61c3d9c21a..35f0c9dee1 100644 --- a/src/unitxt/catalog/cards/xnli/ur.json +++ b/src/unitxt/catalog/cards/xnli/ur.json @@ -13,6 +13,13 @@ } }, "splitters.small_no_test", + { + "type": "rename_fields", + "field_to_field": { + "premise": "text_a", + "hypothesis": "text_b" + } + }, { "type": "map_instance_values", "mappers": { @@ -26,7 +33,10 @@ { "type": "add_fields", "fields": { - "choices": [ + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", + "classes": [ "entailment", "neutral", "contradiction" @@ -34,6 +44,6 @@ } } ], - "task": "tasks.nli", - "templates": "templates.classification.nli.all" + "task": "tasks.classification.multi_class.relation", + "templates": "templates.classification.multi_class.relation.all" } diff --git a/src/unitxt/catalog/cards/xnli/vi.json b/src/unitxt/catalog/cards/xnli/vi.json index bde93864b4..6aeff54404 100644 --- a/src/unitxt/catalog/cards/xnli/vi.json +++ b/src/unitxt/catalog/cards/xnli/vi.json @@ -13,6 +13,13 @@ } }, "splitters.small_no_test", + { + "type": "rename_fields", + "field_to_field": { + "premise": "text_a", + "hypothesis": "text_b" + } + }, { "type": "map_instance_values", "mappers": { @@ -26,7 +33,10 @@ { "type": "add_fields", "fields": { - "choices": [ + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", + "classes": [ "entailment", "neutral", "contradiction" @@ -34,6 +44,6 @@ } } ], - "task": "tasks.nli", - "templates": "templates.classification.nli.all" + "task": "tasks.classification.multi_class.relation", + "templates": "templates.classification.multi_class.relation.all" } diff --git a/src/unitxt/catalog/cards/xnli/zh.json b/src/unitxt/catalog/cards/xnli/zh.json index a64e89e54f..6357a80765 100644 --- a/src/unitxt/catalog/cards/xnli/zh.json +++ b/src/unitxt/catalog/cards/xnli/zh.json @@ -13,6 +13,13 @@ } }, "splitters.small_no_test", + { + "type": "rename_fields", + "field_to_field": { + "premise": "text_a", + "hypothesis": "text_b" + } + }, { "type": "map_instance_values", "mappers": { @@ -26,7 +33,10 @@ { "type": "add_fields", "fields": { - "choices": [ + "type_of_relation": "entailment", + "text_a_type": "premise", + "text_b_type": "hypothesis", + "classes": [ "entailment", "neutral", "contradiction" @@ -34,6 +44,6 @@ } } ], - "task": "tasks.nli", - "templates": "templates.classification.nli.all" + "task": "tasks.classification.multi_class.relation", + "templates": "templates.classification.multi_class.relation.all" } diff --git a/src/unitxt/catalog/tasks/classification/multi_class/relation.json b/src/unitxt/catalog/tasks/classification/multi_class/relation.json new file mode 100644 index 0000000000..e6e87ca59a --- /dev/null +++ b/src/unitxt/catalog/tasks/classification/multi_class/relation.json @@ -0,0 +1,23 @@ +{ + "type": "form_task", + "inputs": [ + "text_a", + "text_a_type", + "text_b", + "text_b_type", + "classes", + "type_of_relation" + ], + "outputs": [ + "label" + ], + "metrics": [ + "metrics.f1_micro", + "metrics.accuracy", + "metrics.f1_macro" + ], + "augmentable_inputs": [ + "text_a", + "text_b" + ] +} diff --git a/src/unitxt/catalog/tasks/classification/multi_class/two_texts.json b/src/unitxt/catalog/tasks/classification/multi_class/two_texts.json new file mode 100644 index 0000000000..688981d678 --- /dev/null +++ b/src/unitxt/catalog/tasks/classification/multi_class/two_texts.json @@ -0,0 +1,23 @@ +{ + "type": "form_task", + "inputs": [ + "text_a", + "text_a_type", + "text_b", + "text_b_type", + "classes", + "type_of_class" + ], + "outputs": [ + "label" + ], + "metrics": [ + "metrics.f1_micro", + "metrics.accuracy", + "metrics.f1_macro" + ], + "augmentable_inputs": [ + "text_a", + "text_b" + ] +} diff --git a/src/unitxt/catalog/tasks/nli.json b/src/unitxt/catalog/tasks/nli.json deleted file mode 100644 index 08a34f442f..0000000000 --- a/src/unitxt/catalog/tasks/nli.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "type": "form_task", - "inputs": [ - "choices", - "premise", - "hypothesis" - ], - "outputs": [ - "label" - ], - "metrics": [ - "metrics.accuracy" - ] -} diff --git a/src/unitxt/catalog/templates/classification/multi_class/relation/all.json b/src/unitxt/catalog/templates/classification/multi_class/relation/all.json new file mode 100644 index 0000000000..50f4ff73b9 --- /dev/null +++ b/src/unitxt/catalog/templates/classification/multi_class/relation/all.json @@ -0,0 +1,6 @@ +{ + "type": "templates_list", + "items": [ + "templates.classification.multi_class.relation.default" + ] +} diff --git a/src/unitxt/catalog/templates/classification/multi_class/relation/default.json b/src/unitxt/catalog/templates/classification/multi_class/relation/default.json new file mode 100644 index 0000000000..0571e5ed1f --- /dev/null +++ b/src/unitxt/catalog/templates/classification/multi_class/relation/default.json @@ -0,0 +1,11 @@ +{ + "type": "input_output_template", + "input_format": "{text_a_type}: {text_a}, {text_b_type}: {text_b}", + "output_format": "{label}", + "target_prefix": "The {type_of_relation} class is ", + "instruction": "Given a {text_a_type} and {text_b_type} classify the {type_of_relation} of the {text_b_type} to one of {classes}.", + "postprocessors": [ + "processors.take_first_non_empty_line", + "processors.lower_case_till_punc" + ] +} diff --git a/src/unitxt/catalog/templates/classification/nli/simple.json b/src/unitxt/catalog/templates/classification/multi_class/relation/simple.json similarity index 62% rename from src/unitxt/catalog/templates/classification/nli/simple.json rename to src/unitxt/catalog/templates/classification/multi_class/relation/simple.json index 58910e57a0..f6ebecaaa4 100644 --- a/src/unitxt/catalog/templates/classification/nli/simple.json +++ b/src/unitxt/catalog/templates/classification/multi_class/relation/simple.json @@ -1,6 +1,6 @@ { "type": "input_output_template", - "input_format": "Given this sentence: {premise}, classify if this sentence: {hypothesis} is {choices}.", + "input_format": "Given this {text_a_type}: {text_a}, classify if this {text_b_type}: {text_b} is {classes}.", "output_format": "{label}", "postprocessors": [ "processors.take_first_non_empty_line", diff --git a/src/unitxt/catalog/templates/classification/multi_label/default.json b/src/unitxt/catalog/templates/classification/multi_label/default.json index 73419de5ad..713d44abfa 100644 --- a/src/unitxt/catalog/templates/classification/multi_label/default.json +++ b/src/unitxt/catalog/templates/classification/multi_label/default.json @@ -6,6 +6,7 @@ "postprocessors": [ "processors.take_first_non_empty_line", "processors.lower_case", - "processors.to_list_by_comma" + "processors.to_list_by_comma", + "processors.remove_none_from_list" ] } diff --git a/src/unitxt/catalog/templates/classification/nli/all.json b/src/unitxt/catalog/templates/classification/nli/all.json deleted file mode 100644 index 9aad15b7c2..0000000000 --- a/src/unitxt/catalog/templates/classification/nli/all.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "type": "templates_list", - "items": [ - "templates.classification.nli.simple" - ] -} diff --git a/src/unitxt/catalog/templates/classification/nli/simple_with_instruction_model_llama.json b/src/unitxt/catalog/templates/classification/nli/simple_with_instruction_model_llama.json deleted file mode 100644 index 06e9cd52c0..0000000000 --- a/src/unitxt/catalog/templates/classification/nli/simple_with_instruction_model_llama.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "type": "input_output_template", - "input_format": "Given this sentence: {premise}, classify if this sentence: {hypothesis} is {choices}.", - "output_format": "{label}", - "instruction": "<>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don’t know the answer to aquestion, please don’t share false information.\n<>\n\n\n\n", - "postprocessors": [ - "processors.take_first_non_empty_line", - "processors.lower_case_till_punc" - ] -} diff --git a/src/unitxt/templates.py b/src/unitxt/templates.py index 38b1291fc3..fa4f7ea0de 100644 --- a/src/unitxt/templates.py +++ b/src/unitxt/templates.py @@ -52,7 +52,7 @@ def process( "target": target, "references": references, "instruction": self.instruction.format(**inputs), - "target_prefix": self.target_prefix, + "target_prefix": self.target_prefix.format(**inputs), } @abstractmethod diff --git a/tests/library/test_recipe.py b/tests/library/test_recipe.py index 2d913c92d2..9dea9d342f 100644 --- a/tests/library/test_recipe.py +++ b/tests/library/test_recipe.py @@ -374,7 +374,7 @@ def test_recipe_with_hf_with_twice_the_same_instance_demos(self): d = load_dataset( dataset_file, - "type=standard_recipe_with_indexes,card=cards.wnli,template=templates.classification.nli.simple,system_prompt=system_prompts.models.llama,demos_pool_size=5,num_demos=5", + "type=standard_recipe_with_indexes,card=cards.wnli,template=templates.classification.multi_class.relation.default,system_prompt=system_prompts.models.llama,demos_pool_size=5,num_demos=5", streaming=True, )