
Bug fix to handle final word in word level argmax #426

Merged: 8 commits, Nov 10, 2021
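Summary of the change: _word_level_argmax in dataprofiler/labelers/data_processing.py previously resolved a word's label only when the loop reached a separator character, so the final word of a sample with no trailing separator was never relabeled. The fix adds an is_end check so the last word gets the same per-word vote, converts the separator look-up structure from a dict to a set, and extends the tests with a 'How nice' sample that has no trailing period.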
28 changes: 14 additions & 14 deletions dataprofiler/labelers/data_processing.py
@@ -815,18 +815,25 @@ def _word_level_argmax(self, data, predictions, label_mapping,
# FORMER DEEPCOPY, SHALLOW AS ONLY INTERNAL
entities_in_sample = list(char_pred)

# Convert to dict for quick look-up
separator_dict = {}
for separator in separators:
separator_dict[separator] = True
# Convert to set for quick look-up
separator_dict = set(separators)

# Iterate over sample
start_idx = 0
label_count = {label_mapping[default_label]: 0}
for idx in range(len(sample)):
Review comment: Would an enumerate not be cleaner?

Reply (Contributor): agreed, but not going to block the PR if @lettergram doesn't get to it prior to merge


# Split on separator
if sample[idx] in separator_dict:
# Split on separator or last sample
is_separator = sample[idx] in separator_dict
is_end = (idx == len(sample)-1 and start_idx > 0)

if not is_separator:
label = entities_in_sample[idx]
if label not in label_count:
label_count[label] = 0
label_count[label] += 1

if is_separator or is_end:

# Find sum of labels over entity
total_label_count = sum(label_count.values())
@@ -858,14 +865,7 @@ def _word_level_argmax(self, data, predictions, label_mapping,
label_count = {background_label: 0}
if char_pred[idx] == background_label and \
sample[idx] in separator_dict:
continue

# Keep count of labels since start
label = entities_in_sample[idx]
if label not in label_count:
label_count[label] = 0
label_count[label] += 1

continue
word_level_predictions.append(entities_in_sample)

return word_level_predictions
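The sketch below reassembles the updated word-level argmax loop in a self-contained form, as an illustration only. The collapsed portion of the diff (how a word's label is chosen from label_count, including any threshold parameters) is not visible above, so a plain majority vote is assumed here; the function name and the default background_label value are hypothetical, and the loop uses enumerate as the reviewer suggested.

# Hypothetical, simplified sketch of the updated word-level argmax. Not the
# dataprofiler implementation: the collapsed part of the diff (how the word's
# label is picked from label_count, plus any threshold parameters) is
# approximated here by a plain majority vote.
def word_level_argmax_sketch(sample, char_pred, separators, background_label=1):
    """Return char_pred with each word's characters set to that word's winning label."""
    separator_set = set(separators)      # set gives O(1) membership checks
    entities = list(char_pred)           # shallow copy, edited in place
    start_idx = 0
    label_count = {background_label: 0}

    for idx, char in enumerate(sample):  # enumerate, as suggested in the review above
        is_separator = char in separator_set
        is_end = (idx == len(sample) - 1 and start_idx > 0)

        if not is_separator:
            # Tally the character-level label for the word in progress
            label_count[char_pred[idx]] = label_count.get(char_pred[idx], 0) + 1

        if is_separator or is_end:
            # Resolve the word that just ended; before this PR the is_end case
            # was missing, so a trailing word with no separator kept its raw labels
            if sum(label_count.values()) > 0:
                word_label = max(label_count, key=label_count.get)
                end = idx + 1 if (is_end and not is_separator) else idx
                for i in range(start_idx, end):
                    entities[i] = word_label
            start_idx = idx + 1
            label_count = {background_label: 0}

    return entities

# With the new test sample (label_mapping: UNKNOWN=1, TEST1=2, TEST2=3):
#   word_level_argmax_sketch('How nice', [2, 2, 1, 1, 3, 1, 3, 3], ' .,')
#   -> [2, 2, 2, 1, 3, 3, 3, 3]
# matching the last expected_output block in test_word_level_argmax below.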
38 changes: 29 additions & 9 deletions dataprofiler/tests/labelers/test_data_processing.py
@@ -901,13 +901,15 @@ def test_get_parameters(self):
def test_word_level_argmax(self):

# input data initialization
data = np.array(['this is my test sentence.', 'How nice.'])
data = np.array(['this is my test sentence.', 'How nice.', 'How nice'])
predictions = [
# this is my test sentence.
[1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2,
3, 3, 1],
# How nice.
[2, 2, 1, 1, 3, 1, 3, 3, 1]
[2, 2, 1, 1, 3, 1, 3, 3, 1],
# How nice
[2, 2, 1, 1, 3, 1, 3, 3]
]
label_mapping = {
'PAD': 0,
@@ -925,7 +927,9 @@ def test_word_level_argmax(self):
[1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2,
3, 3, 1],
# How nice.
[2, 2, 1, 1, 3, 1, 3, 3, 1]
[2, 2, 1, 1, 3, 1, 3, 3, 1],
# How nice
[2, 2, 1, 1, 3, 1, 3, 3]
]
output = processor._word_level_argmax(
data, predictions, label_mapping, default_label)
@@ -938,7 +942,9 @@ def test_word_level_argmax(self):
[1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 1],
# How nice.
[1, 1, 1, 1, 3, 3, 3, 3, 1]
[1, 1, 1, 1, 3, 3, 3, 3, 1],
# How nice
[1, 1, 1, 1, 3, 3, 3, 3]
]
output = processor._word_level_argmax(
data, predictions, label_mapping, default_label)
@@ -951,7 +957,9 @@ def test_word_level_argmax(self):
[1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1],
# How nice.
[1, 1, 1, 1, 1, 1, 1, 1, 1]
[1, 1, 1, 1, 1, 1, 1, 1, 1],
# How nice
[1, 1, 1, 1, 1, 1, 1, 1]
]
output = processor._word_level_argmax(
data, predictions, label_mapping, default_label)
@@ -964,21 +972,25 @@ def test_word_level_argmax(self):
[1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 1],
# How nice.
[2, 2, 2, 1, 3, 3, 3, 3, 1]
[2, 2, 2, 1, 3, 3, 3, 3, 1],
# How nice
[2, 2, 2, 1, 3, 3, 3, 3]
]
output = processor._word_level_argmax(
data, predictions, label_mapping, default_label)
self.assertListEqual(expected_output, output)

def test_convert_to_NER_format(self):
# input data initialization
data = np.array(['this is my test sentence.', 'How nice.'])
data = np.array(['this is my test sentence.', 'How nice.', 'How nice'])
predictions = [
# this is my test sentence.
[1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2,
3, 3, 1],
# How nice.
[2, 2, 1, 1, 3, 1, 3, 3, 1]
[2, 2, 1, 1, 3, 1, 3, 3, 1],
# How nice
[2, 2, 2, 1, 3, 1, 3, 3]
]
label_mapping = {
'PAD': 0,
@@ -999,7 +1011,12 @@ def test_convert_to_NER_format(self):
[
(0, 2, 'TEST1'),
(4, 5, 'TEST2'),
(6, 8, 'TEST2')],
[
(0, 3, 'TEST1'),
(4, 5, 'TEST2'),
(6, 8, 'TEST2')]

]

output = processor.convert_to_NER_format(
@@ -1017,7 +1034,10 @@ def test_convert_to_NER_format(self):
[
( 2, 4, 'UNKNOWN'),
( 5, 6, 'UNKNOWN'),
( 8, 9, 'UNKNOWN')]
( 8, 9, 'UNKNOWN')],
[
( 3, 4, 'UNKNOWN'),
( 5, 6, 'UNKNOWN')]
]

output = processor.convert_to_NER_format(
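A closing note on the NER-format expectations above: each tuple is (start, end, label), where start and end are character offsets into the sample and end is exclusive. The minimal converter below only illustrates how character-level labels collapse into such spans; to_spans_sketch and reverse_mapping are hypothetical names, it is not the library's convert_to_NER_format, and the library method's behavior also depends on processor parameters not shown in this diff.

# Minimal sketch (illustrative only) of grouping character-level labels into
# (start, end, label) spans like the tuples asserted in the tests above.
def to_spans_sketch(char_labels, reverse_mapping, ignore_id=1):
    spans = []
    start = current = None
    for i, lab in enumerate(list(char_labels) + [ignore_id]):  # sentinel flushes the last span
        if start is None:
            if lab != ignore_id:
                start, current = i, lab
        elif lab != current:
            spans.append((start, i, reverse_mapping[current]))
            start, current = (i, lab) if lab != ignore_id else (None, None)
    return spans

# For the new 'How nice' sample with predictions [2, 2, 2, 1, 3, 1, 3, 3] and
# reverse_mapping {2: 'TEST1', 3: 'TEST2'} this yields
# [(0, 3, 'TEST1'), (4, 5, 'TEST2'), (6, 8, 'TEST2')], matching the test above.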