Skip to content

Commit

Permalink
Add multi class relation classification task and change nli datasets …
Browse files Browse the repository at this point in the history
…to use it

Signed-off-by: Elron Bandel <elron.bandel@ibm.com>
  • Loading branch information
elronbandel committed Feb 20, 2024
1 parent f7af217 commit 5de25dc
Show file tree
Hide file tree
Showing 38 changed files with 398 additions and 162 deletions.
11 changes: 8 additions & 3 deletions prepare/cards/mnli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
AddFields,
LoadHF,
MapInstanceValues,
RenameFields,
TaskCard,
)
from src.unitxt.catalog import add_to_catalog
Expand All @@ -13,17 +14,21 @@
preprocess_steps=[
RenameSplits({"validation_matched": "validation"}),
"splitters.small_no_test",
RenameFields(field_to_field={"premise": "text_a", "hypothesis": "text_b"}),
MapInstanceValues(
mappers={"label": {"0": "entailment", "1": "neutral", "2": "contradiction"}}
),
AddFields(
fields={
"choices": ["entailment", "neutral", "contradiction"],
"type_of_relation": "entailment",
"text_a_type": "premise",
"text_b_type": "hypothesis",
"classes": ["entailment", "neutral", "contradiction"],
}
),
],
task="tasks.nli",
templates="templates.classification.nli.all",
task="tasks.classification.multi_class.relation",
templates="templates.classification.multi_class.relation.all",
)

test_card(card)
Expand Down
19 changes: 8 additions & 11 deletions prepare/cards/qnli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,12 @@
AddFields,
LoadHF,
MapInstanceValues,
SplitRandomMix,
TaskCard,
)
from src.unitxt.catalog import add_to_catalog
from src.unitxt.operators import RenameFields
from src.unitxt.test_utils.card import test_card

default_splitter = SplitRandomMix(
{"train": "train", "validation": "validation", "test": "test"}
)
add_to_catalog(default_splitter, "splitters.default", overwrite=True)

card = TaskCard(
loader=LoadHF(path="glue", name="qnli"),
preprocess_steps=[
Expand All @@ -23,18 +17,21 @@
),
AddFields(
fields={
"choices": ["entailment", "not entailment"],
"classes": ["entailment", "not entailment"],
"type_of_relation": "entailment",
"text_a_type": "question",
"text_b_type": "sentence",
}
),
RenameFields(
field_to_field={
"question": "premise",
"sentence": "hypothesis",
"question": "text_a",
"sentence": "text_b",
}
),
],
task="tasks.nli",
templates="templates.classification.nli.all",
task="tasks.classification.multi_class.relation",
templates="templates.classification.multi_class.relation.all",
)

test_card(card)
Expand Down
22 changes: 8 additions & 14 deletions prepare/cards/rte.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from src.unitxt.blocks import (
AddFields,
FormTask,
LoadHF,
MapInstanceValues,
RenameFields,
Expand All @@ -9,14 +8,6 @@
from src.unitxt.catalog import add_to_catalog
from src.unitxt.test_utils.card import test_card

nli_task = FormTask(
inputs=["choices", "premise", "hypothesis"],
outputs=["label"],
metrics=["metrics.accuracy"],
)

add_to_catalog(nli_task, "tasks.nli", overwrite=True)

card = TaskCard(
loader=LoadHF(path="glue", name="rte"),
preprocess_steps=[
Expand All @@ -26,18 +17,21 @@
),
AddFields(
fields={
"choices": ["entailment", "not entailment"],
"classes": ["entailment", "not entailment"],
"type_of_relation": "entailment",
"text_a_type": "premise",
"text_b_type": "hypothesis",
}
),
RenameFields(
field_to_field={
"sentence1": "premise",
"sentence2": "hypothesis",
"sentence1": "text_a",
"sentence2": "text_b",
}
),
],
task="tasks.nli",
templates="templates.classification.nli.all",
task="tasks.classification.multi_class.relation",
templates="templates.classification.multi_class.relation.all",
)

test_card(card)
Expand Down
13 changes: 8 additions & 5 deletions prepare/cards/wnli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,24 @@
"splitters.small_no_test",
RenameFields(
field_to_field={
"sentence1": "premise",
"sentence2": "hypothesis",
"sentence1": "text_a",
"sentence2": "text_b",
}
),
MapInstanceValues(
mappers={"label": {"0": "entailment", "1": "not entailment"}}
),
AddFields(
fields={
"choices": ["entailment", "not entailment"],
"classes": ["entailment", "not entailment"],
"type_of_relation": "entailment",
"text_a_type": "premise",
"text_b_type": "hypothesis",
}
),
],
task="tasks.nli",
templates="templates.classification.nli.all",
task="tasks.classification.multi_class.relation",
templates="templates.classification.multi_class.relation.all",
)

test_card(card)
Expand Down
13 changes: 9 additions & 4 deletions prepare/cards/xnli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
AddFields,
LoadHF,
MapInstanceValues,
RenameFields,
TaskCard,
)
from src.unitxt.catalog import add_to_catalog
Expand Down Expand Up @@ -33,20 +34,24 @@
preprocess_steps=[
RenameSplits({"validation_matched": "validation"}),
"splitters.small_no_test",
RenameFields(field_to_field={"premise": "text_a", "hypothesis": "text_b"}),
MapInstanceValues(
mappers={
"label": {"0": "entailment", "1": "neutral", "2": "contradiction"}
}
),
AddFields(
fields={
"choices": ["entailment", "neutral", "contradiction"],
"type_of_relation": "entailment",
"text_a_type": "premise",
"text_b_type": "hypothesis",
"classes": ["entailment", "neutral", "contradiction"],
}
),
],
task="tasks.nli",
templates="templates.classification.nli.all",
task="tasks.classification.multi_class.relation",
templates="templates.classification.multi_class.relation.all",
)
if lang == lang[0]:
if lang == langs[0]:
test_card(card)
add_to_catalog(card, f"cards.xnli.{lang}", overwrite=True)
18 changes: 18 additions & 0 deletions prepare/tasks/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,21 @@
"tasks.classification.multi_class",
overwrite=True,
)

# Register the two-text relation classification task in the catalog.
# Every instance carries the two texts, a short name for what each text is,
# the list of candidate classes, and the name of the relation being classified.
_relation_task_inputs = [
    "text_a",
    "text_a_type",
    "text_b",
    "text_b_type",
    "classes",
    "type_of_relation",
]
_relation_task = FormTask(
    inputs=_relation_task_inputs,
    outputs=["label"],
    metrics=["metrics.f1_micro", "metrics.accuracy", "metrics.f1_macro"],
    # Only the two text fields are listed as augmentable inputs.
    augmentable_inputs=["text_a", "text_b"],
)
add_to_catalog(
    _relation_task,
    "tasks.classification.multi_class.relation",
    overwrite=True,
)
36 changes: 36 additions & 0 deletions prepare/templates/classification/multi_class/relation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from src.unitxt.catalog import add_to_catalog
from src.unitxt.templates import InputOutputTemplate, TemplatesList

# Default relation template: renders both texts prefixed by their type names,
# carries an instruction listing the candidate classes, and prefixes the
# target with the relation name.
_default_relation_template = InputOutputTemplate(
    input_format="{text_a_type}: {text_a}, {text_b_type}: {text_b}",
    output_format="{label}",
    target_prefix="The {type_of_relation} class is ",
    instruction="Given a {text_a_type} and {text_b_type} classify the {type_of_relation} of the {text_b_type} to one of {classes}.",
    # Postprocessors are referenced by catalog name and applied in this order.
    postprocessors=[
        "processors.take_first_non_empty_line",
        "processors.lower_case_till_punc",
    ],
)
add_to_catalog(
    _default_relation_template,
    "templates.classification.multi_class.relation.default",
    overwrite=True,
)

# Simple relation template: a single-sentence prompt that embeds both texts
# and the candidate classes directly in the input; it sets no separate
# instruction or target prefix.
_simple_relation_template = InputOutputTemplate(
    input_format="Given this {text_a_type}: {text_a}, classify if this {text_b_type}: {text_b} is {classes}.",
    output_format="{label}",
    # Postprocessors are referenced by catalog name and applied in this order.
    postprocessors=[
        "processors.take_first_non_empty_line",
        "processors.lower_case_till_punc",
    ],
)
add_to_catalog(
    _simple_relation_template,
    "templates.classification.multi_class.relation.simple",
    overwrite=True,
)

# Template collection registered under "...relation.all".
# NOTE(review): only the "default" template is listed here; the "simple"
# template registered in this file is not included - confirm this is intended.
_all_relation_templates = TemplatesList(
    ["templates.classification.multi_class.relation.default"]
)
add_to_catalog(
    _all_relation_templates,
    "templates.classification.multi_class.relation.all",
    overwrite=True,
)
25 changes: 0 additions & 25 deletions prepare/templates/classification/nli/templates.py

This file was deleted.

16 changes: 13 additions & 3 deletions src/unitxt/catalog/cards/mnli.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@
}
},
"splitters.small_no_test",
{
"type": "rename_fields",
"field_to_field": {
"premise": "text_a",
"hypothesis": "text_b"
}
},
{
"type": "map_instance_values",
"mappers": {
Expand All @@ -26,14 +33,17 @@
{
"type": "add_fields",
"fields": {
"choices": [
"type_of_relation": "entailment",
"text_a_type": "premise",
"text_b_type": "hypothesis",
"classes": [
"entailment",
"neutral",
"contradiction"
]
}
}
],
"task": "tasks.nli",
"templates": "templates.classification.nli.all"
"task": "tasks.classification.multi_class.relation",
"templates": "templates.classification.multi_class.relation.all"
}
15 changes: 9 additions & 6 deletions src/unitxt/catalog/cards/qnli.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,23 @@
{
"type": "add_fields",
"fields": {
"choices": [
"classes": [
"entailment",
"not entailment"
]
],
"type_of_relation": "entailment",
"text_a_type": "question",
"text_b_type": "sentence"
}
},
{
"type": "rename_fields",
"field_to_field": {
"question": "premise",
"sentence": "hypothesis"
"question": "text_a",
"sentence": "text_b"
}
}
],
"task": "tasks.nli",
"templates": "templates.classification.nli.all"
"task": "tasks.classification.multi_class.relation",
"templates": "templates.classification.multi_class.relation.all"
}
15 changes: 9 additions & 6 deletions src/unitxt/catalog/cards/rte.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,23 @@
{
"type": "add_fields",
"fields": {
"choices": [
"classes": [
"entailment",
"not entailment"
]
],
"type_of_relation": "entailment",
"text_a_type": "premise",
"text_b_type": "hypothesis"
}
},
{
"type": "rename_fields",
"field_to_field": {
"sentence1": "premise",
"sentence2": "hypothesis"
"sentence1": "text_a",
"sentence2": "text_b"
}
}
],
"task": "tasks.nli",
"templates": "templates.classification.nli.all"
"task": "tasks.classification.multi_class.relation",
"templates": "templates.classification.multi_class.relation.all"
}
Loading

0 comments on commit 5de25dc

Please sign in to comment.