Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added table serializers operators and add Wikitq table question answering dataset #544

Merged
merged 4 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions prepare/cards/wikitq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from src.unitxt.blocks import (
CopyFields,
IndexedRowMajorTableSerializer,
LoadHF,
TaskCard,
)
from src.unitxt.catalog import add_to_catalog
from src.unitxt.test_utils.card import test_card

card = TaskCard(
loader=LoadHF(path="wikitablequestions"),
preprocess_steps=[
"splitters.small_no_test",
CopyFields(field_to_field=[["answers", "answer"]], use_query=True),
IndexedRowMajorTableSerializer(field_to_field=[["table", "context"]]),
],
task="tasks.qa.contextual.extractive",
templates="templates.qa.contextual.all",
)

test_card(card)
add_to_catalog(card, "cards.wikitq", overwrite=True)
4 changes: 4 additions & 0 deletions src/unitxt/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
)
from .processors import ToString, ToStringStripped
from .recipe import SequentialRecipe
from .serializers import (
IndexedRowMajorTableSerializer,
MarkdownTableSerializer,
)
from .splitters import RandomSampler, SliceSplit, SplitRandomMix, SpreadSplit
from .stream import MultiStream
from .task import FormTask
Expand Down
31 changes: 31 additions & 0 deletions src/unitxt/catalog/cards/wikitq.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
{
"type": "task_card",
"loader": {
"type": "load_hf",
"path": "wikitablequestions"
},
"preprocess_steps": [
"splitters.small_no_test",
{
"type": "copy_fields",
"field_to_field": [
[
"answers",
"answer"
]
],
"use_query": true
},
{
"type": "indexed_row_major_table_serializer",
"field_to_field": [
[
"table",
"context"
]
]
}
],
"task": "tasks.qa.contextual.extractive",
"templates": "templates.qa.contextual.all"
}
130 changes: 130 additions & 0 deletions src/unitxt/serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import (
Any,
Dict,
List,
)

from .operators import FieldOperator

"""
TableSerializer converts a given table into a flat sequence with special symbols.
Input table format must be:
{"header": ["col1", "col2"], "rows": [["row11", "row12"], ["row21", "row22"], ["row31", "row32"]]}
Output format varies depending on the chosen serializer. Abstract class at the top defines structure of a typical table serializer that any concrete implementation should follow.
"""


class TableSerializer(ABC, FieldOperator):
# main method to serialize a table
@abstractmethod
def serialize_table(self, table_content: Dict) -> str:
pass

# method to process table header
@abstractmethod
def process_header(self, header: List):
pass

# method to process a table row
@abstractmethod
def process_row(self, row: List, row_index: int):
pass


# Concrete classes implementing table serializers follow..
"""
Indexed Row Major Table Serializer.
Commonly used row major serialization format.
Format: col : col1 | col2 | col 3 row 1 : val1 | val2 | val3 | val4 row 2 : val1 | ...
"""


class IndexedRowMajorTableSerializer(TableSerializer):
def process_value(self, table: Any) -> Any:
table_input = deepcopy(table)
return self.serialize_table(table_content=table_input)

# main method that processes a table
# table_content must be in the presribed input format
def serialize_table(self, table_content: Dict) -> str:
# Extract headers and rows from the dictionary
header = table_content.get("header", [])
rows = table_content.get("rows", [])

assert header and rows, "Incorrect input table format"

# Process table header first
serialized_tbl_str = self.process_header(header) + " "

# Process rows sequentially starting from row 1
for i, row in enumerate(rows, start=1):
serialized_tbl_str += self.process_row(row, row_index=i) + " "

# return serialized table as a string
return serialized_tbl_str.strip()

# serialize header into a string containing the list of column names separated by '|' symbol
def process_header(self, header: List):
return "col : " + " | ".join(header)

# serialize a table row into a string containing the list of cell values separated by '|'
def process_row(self, row: List, row_index: int):
serialized_row_str = ""
row_cell_values = [
str(value) if isinstance(value, (int, float)) else value for value in row
]

serialized_row_str += " | ".join(row_cell_values)

return f"row {row_index} : {serialized_row_str}"


"""
Markdown Table Serializer.
Markdown table format is used in GitHub code primarily.
Format:
|col1|col2|col3|
|---|---|---|
|A|4|1|
|I|2|1|
...
"""


class MarkdownTableSerializer(TableSerializer):
def process_value(self, table: Any) -> Any:
table_input = deepcopy(table)
return self.serialize_table(table_content=table_input)

# main method that serializes a table.
# table_content must be in the presribed input format.
def serialize_table(self, table_content: Dict) -> str:
# Extract headers and rows from the dictionary
header = table_content.get("header", [])
rows = table_content.get("rows", [])

assert header and rows, "Incorrect input table format"

# Process table header first
serialized_tbl_str = self.process_header(header)

# Process rows sequentially starting from row 1
for i, row in enumerate(rows, start=1):
serialized_tbl_str += self.process_row(row, row_index=i)

# return serialized table as a string
return serialized_tbl_str.strip()

# serialize header into a string containing the list of column names
def process_header(self, header: List):
header_str = "|{}|\n".format("|".join(header))
header_str += "|{}|\n".format("|".join(["---"] * len(header)))
return header_str

# serialize a table row into a string containing the list of cell values
def process_row(self, row: List, row_index: int):
row_str = ""
row_str += "|{}|\n".format("|".join(str(cell) for cell in row))
return row_str
75 changes: 75 additions & 0 deletions tests/test_serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import unittest

from src.unitxt.serializers import (
IndexedRowMajorTableSerializer,
MarkdownTableSerializer,
)
from src.unitxt.test_utils.operators import (
check_operator,
)


class TestSerializers(unittest.TestCase):
def test_markdown_tableserializer(self):
inputs = [
{
"table": {
"header": ["name", "age"],
"rows": [["Alex", "26"], ["Raj", "34"], ["Donald", "39"]],
}
}
]

serialized_str = "|name|age|\n|---|---|\n|Alex|26|\n|Raj|34|\n|Donald|39|"

targets = [
{
"table": {
"header": ["name", "age"],
"rows": [["Alex", "26"], ["Raj", "34"], ["Donald", "39"]],
},
"serialized_table": serialized_str,
}
]

check_operator(
operator=MarkdownTableSerializer(
field_to_field={"table": "serialized_table"}
),
inputs=inputs,
targets=targets,
tester=self,
)

def test_indexedrowmajor_tableserializer(self):
inputs = [
{
"table": {
"header": ["name", "age"],
"rows": [["Alex", "26"], ["Raj", "34"], ["Donald", "39"]],
}
}
]

serialized_str = (
"col : name | age row 1 : Alex | 26 row 2 : Raj | 34 row 3 : Donald | 39"
)

targets = [
{
"table": {
"header": ["name", "age"],
"rows": [["Alex", "26"], ["Raj", "34"], ["Donald", "39"]],
},
"serialized_table": serialized_str,
}
]

check_operator(
operator=IndexedRowMajorTableSerializer(
field_to_field={"table": "serialized_table"}
),
inputs=inputs,
targets=targets,
tester=self,
)
Loading