-
Notifications
You must be signed in to change notification settings - Fork 169
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ColumnNameLabeler Setup #635
Changes from 12 commits
5f90417
a924f28
8cd9074
e18bb20
5a3209a
dcc1052
f241707
6fc59e2
1d46359
035b1f1
c31b49f
79c2ab0
5565c41
c1f5ca1
ce1dc35
e77ca9e
8a0a686
c85635e
a2d1586
a321037
9db41c7
78c29fc
d63398a
1c442aa
3cdc143
0d7a49e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ | |
class ColumnNameModel(BaseModel, metaclass=AutoSubRegistrationMeta): | ||
"""Class for column name data labeling model.""" | ||
|
||
def __init__(self, parameters=None): | ||
def __init__(self, label_mapping=None, parameters=None): | ||
"""Initialize function for ColumnNameModel. | ||
|
||
:param parameters: Contains all the appropriate parameters for the model. | ||
|
@@ -39,7 +39,8 @@ def __init__(self, parameters=None): | |
parameters.setdefault("include_label", True) | ||
parameters.setdefault("negative_threshold_config", None) | ||
|
||
# initialize class | ||
# validate and set parameters | ||
self.set_label_mapping(label_mapping) | ||
self._validate_parameters(parameters) | ||
self._parameters = parameters | ||
|
||
|
@@ -220,8 +221,8 @@ def predict( | |
if show_confidences: | ||
raise NotImplementedError( | ||
"""`show_confidences` parameter is disabled | ||
for Proof of Concept implementation. Confidence | ||
values are enabled by default.""" | ||
for MVP implementation. Note: Confidence | ||
values are returned by default.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My thought is we look at this in a re-work of the pipeline. Will require some thought how the data is passed into the labeler flow and re formatting of the |
||
) | ||
|
||
if verbose: | ||
|
@@ -243,7 +244,12 @@ def load_from_disk(cls, dirpath): | |
with open(model_param_dirpath, "r") as fp: | ||
parameters = json.load(fp) | ||
|
||
loaded_model = cls(parameters) | ||
# load label_mapping | ||
labels_dirpath = os.path.join(dirpath, "label_mapping.json") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice |
||
with open(labels_dirpath, "r") as fp: | ||
label_mapping = json.load(fp) | ||
|
||
loaded_model = cls(label_mapping, parameters) | ||
return loaded_model | ||
|
||
def save_to_disk(self, dirpath): | ||
|
@@ -260,3 +266,7 @@ def save_to_disk(self, dirpath): | |
model_param_dirpath = os.path.join(dirpath, "model_parameters.json") | ||
with open(model_param_dirpath, "w") as fp: | ||
json.dump(self._parameters, fp) | ||
|
||
labels_dirpath = os.path.join(dirpath, "label_mapping.json") | ||
with open(labels_dirpath, "w") as fp: | ||
json.dump(self.label_mapping, fp) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2175,14 +2175,16 @@ def process(self, data, labels=None, label_mapping=None, batch_size=None): | |
"""Preprocess data.""" | ||
results = {} | ||
for iter_value, value in enumerate(data): | ||
if data[iter_value][0] > self._parameters["positive_threshold_config"]: | ||
results[iter_value] = {} | ||
try: | ||
try: | ||
if ( | ||
labels[iter_value][0] | ||
> self._parameters["positive_threshold_config"] | ||
): | ||
results[iter_value] = {} | ||
results[iter_value]["pred"] = self._parameters[ | ||
"true_positive_dict" | ||
][data[iter_value][1]]["label"] | ||
except IndexError: | ||
pass | ||
results[iter_value]["conf"] = data[iter_value][0] | ||
|
||
][labels[iter_value][1]]["label"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo fix after realizing |
||
results[iter_value]["conf"] = data[iter_value][0] | ||
except IndexError: | ||
pass | ||
return results |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,10 +38,14 @@ | |
"include_label": True, | ||
} | ||
|
||
mock_label_mapping = {"ssn": 1, "name": 2, "address": 3} | ||
|
||
|
||
def mock_open(filename, *args): | ||
if filename.find("model_parameters") >= 0: | ||
return StringIO(json.dumps(mock_model_parameters)) | ||
elif filename.find("label_mapping") >= 0: | ||
return StringIO(json.dumps(mock_label_mapping)) | ||
|
||
|
||
def setup_save_mock_open(mock_open): | ||
|
@@ -52,16 +56,12 @@ def setup_save_mock_open(mock_open): | |
|
||
|
||
class TestColumnNameModel(unittest.TestCase): | ||
def setUp(self): | ||
@classmethod | ||
def setUp(cls): | ||
# data | ||
data = [ | ||
"ssn", | ||
"role_name", | ||
"wallet_address", | ||
] | ||
cls.data = ["ssn", "role_name", "wallet_address"] | ||
|
||
def test_param_validation(self): | ||
invalid_parameters = [ | ||
cls.invalid_parameters = [ | ||
{ | ||
"false_positive_dict": [ | ||
{ | ||
|
@@ -79,12 +79,44 @@ def test_param_validation(self): | |
}, | ||
] | ||
|
||
model = ColumnNameModel(parameters=mock_model_parameters) | ||
cls.parameters = { | ||
"true_positive_dict": [ | ||
{"attribute": "ssn", "label": "ssn"}, | ||
{"attribute": "suffix", "label": "name"}, | ||
{"attribute": "my_home_address", "label": "address"}, | ||
], | ||
"false_positive_dict": [ | ||
{ | ||
"attribute": "contract_number", | ||
"label": "ssn", | ||
}, | ||
{ | ||
"attribute": "role", | ||
"label": "name", | ||
}, | ||
{ | ||
"attribute": "send_address", | ||
"label": "address", | ||
}, | ||
], | ||
"negative_threshold_config": 50, | ||
"include_label": True, | ||
} | ||
|
||
cls.test_label_mapping = {"ssn": 1, "name": 2, "address": 3} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any reason this over: |
||
|
||
def test_param_validation(self): | ||
|
||
model = ColumnNameModel( | ||
label_mapping=self.test_label_mapping, parameters=mock_model_parameters | ||
) | ||
self.assertDictEqual(mock_model_parameters, model._parameters) | ||
|
||
for invalid_param_set in invalid_parameters: | ||
for invalid_param_set in self.invalid_parameters: | ||
with self.assertRaises(ValueError): | ||
ColumnNameModel(parameters=invalid_param_set) | ||
ColumnNameModel( | ||
label_mapping=self.test_label_mapping, parameters=invalid_param_set | ||
) | ||
|
||
@mock.patch("sys.stdout", new_callable=StringIO) | ||
def test_help(self, mock_stdout): | ||
|
@@ -95,21 +127,21 @@ def test_help(self, mock_stdout): | |
@mock.patch("sys.stdout", new_callable=StringIO) | ||
def test_predict(self, mock_stdout): | ||
# test show confidences | ||
model = ColumnNameModel(parameters=mock_model_parameters) | ||
model = ColumnNameModel( | ||
label_mapping=self.test_label_mapping, parameters=mock_model_parameters | ||
) | ||
expected_output = [[100.0, 0]] | ||
with self.assertLogs( | ||
"DataProfiler.labelers.column_name_model", level="INFO" | ||
) as logs: | ||
model_output = model.predict(data=["ssn", "role_name", "wallet_address"]) | ||
model_output = model.predict(data=self.data) | ||
self.assertTrue(np.array_equal(expected_output, model_output)) | ||
self.assertTrue(len(logs.output)) | ||
|
||
# `show_confidences` is disabled currently | ||
# should raise error if set to `True` | ||
with self.assertRaises(NotImplementedError): | ||
model.predict( | ||
data=["ssn", "role_name", "wallet_address"], show_confidences=True | ||
) | ||
model.predict(data=self.data, show_confidences=True) | ||
|
||
# clear stdout | ||
mock_stdout.seek(0) | ||
|
@@ -136,37 +168,24 @@ def test_save(self, mock_open, *mocks): | |
# setup mock | ||
mock_file = setup_save_mock_open(mock_open) | ||
|
||
# Save and load a Model with custom parameters | ||
parameters = { | ||
"true_positive_dict": [ | ||
{"attribute": "ssn", "label": "ssn"}, | ||
{"attribute": "suffix", "label": "name"}, | ||
{"attribute": "my_home_address", "label": "address"}, | ||
], | ||
"false_positive_dict": [ | ||
{ | ||
"attribute": "contract_number", | ||
"label": "ssn", | ||
}, | ||
{ | ||
"attribute": "role", | ||
"label": "name", | ||
}, | ||
{ | ||
"attribute": "send_address", | ||
"label": "address", | ||
}, | ||
], | ||
"negative_threshold_config": 50, | ||
"include_label": True, | ||
} | ||
|
||
model = ColumnNameModel(parameters) | ||
model = ColumnNameModel( | ||
label_mapping=mock_model_parameters, parameters=self.parameters | ||
) | ||
|
||
model.save_to_disk(".") | ||
self.assertDictEqual( | ||
parameters, | ||
json.loads(mock_file.getvalue()), | ||
self.assertEqual( | ||
'{"true_positive_dict": [{"attribute": "ssn", "label": "ssn"}, ' | ||
'{"attribute": "suffix", "label": "name"}, {"attribute": "my_home_address", ' | ||
'"label": "address"}], "false_positive_dict": [{"attribute": ' | ||
'"contract_number", "label": "ssn"}, {"attribute": "role", ' | ||
'"label": "name"}, {"attribute": "send_address", "label": "address"}], ' | ||
'"negative_threshold_config": 50, "include_label": true}{"true_positive_dict": ' | ||
'[{"attribute": "ssn", "label": "ssn"}, {"attribute": "suffix", ' | ||
'"label": "name"}, {"attribute": "my_home_address", "label": "address"}], ' | ||
'"false_positive_dict": [{"attribute": "contract_number", "label": "ssn"}, ' | ||
'{"attribute": "role", "label": "name"}, {"attribute": "send_address", "label": ' | ||
'"address"}], "negative_threshold_config": 50, "include_label": true}', | ||
mock_file.getvalue(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. test load of model (label mapping and model parameters)
micdavis marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
|
||
# close mock | ||
|
@@ -195,6 +214,15 @@ def test_load(self, *mocks): | |
loaded_model._parameters["negative_threshold_config"], | ||
) | ||
|
||
def test_reverse_label_mapping(self): | ||
"""test reverse label mapping is propagating | ||
through the classes correctly""" | ||
reverse_label_mapping = {v: k for k, v in self.test_label_mapping.items()} | ||
model = ColumnNameModel( | ||
label_mapping=self.test_label_mapping, parameters=self.parameters | ||
) | ||
self.assertEqual(model.reverse_label_mapping, reverse_label_mapping) | ||
|
||
def missing_module_test(self, class_name, module_name): | ||
orig_import = __import__ | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import os | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. brand new file for testing the pre processor, model, and post processor all together |
||
import unittest | ||
|
||
import numpy as np | ||
import pkg_resources | ||
|
||
from dataprofiler.labelers.column_name_model import ColumnNameModel | ||
from dataprofiler.labelers.data_labelers import BaseDataLabeler | ||
from dataprofiler.labelers.data_processing import ( | ||
ColumnNameModelPostprocessor, | ||
DirectPassPreprocessor, | ||
) | ||
|
||
default_labeler_dir = pkg_resources.resource_filename("resources", "labelers") | ||
|
||
|
||
class TestColumnNameDataLabeler(unittest.TestCase): | ||
@classmethod | ||
def setUpClass(cls) -> None: | ||
cls.one_data = ["ssn"] | ||
cls.two_data = ["ssn", "failing_fail_fail"] | ||
|
||
cls.parameters = { | ||
"true_positive_dict": [ | ||
{"attribute": "ssn", "label": "ssn"}, | ||
{"attribute": "suffix", "label": "name"}, | ||
{"attribute": "my_home_address", "label": "address"}, | ||
], | ||
"false_positive_dict": [ | ||
{ | ||
"attribute": "contract_number", | ||
"label": "ssn", | ||
}, | ||
{ | ||
"attribute": "role", | ||
"label": "name", | ||
}, | ||
{ | ||
"attribute": "send_address", | ||
"label": "address", | ||
}, | ||
], | ||
"negative_threshold_config": 50, | ||
"include_label": True, | ||
} | ||
|
||
cls.label_mapping = [ | ||
label["label"] for label in cls.parameters["true_positive_dict"] | ||
] | ||
|
||
preprocessor = DirectPassPreprocessor() | ||
model = ColumnNameModel( | ||
label_mapping=cls.label_mapping, parameters=cls.parameters | ||
) | ||
postprocessor = ColumnNameModelPostprocessor( | ||
true_positive_dict=cls.parameters["true_positive_dict"], | ||
positive_threshold_config=85, | ||
) | ||
|
||
cls.data_labeler = BaseDataLabeler.load_with_components( | ||
preprocessor=preprocessor, model=model, postprocessor=postprocessor | ||
) | ||
|
||
def test_default_model(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we also need a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fair point yeah should test load from library for the model for sure There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added and actually caught a typo in the resources too -- 👍 |
||
"""simple test of the DataLabeler's predict""" | ||
|
||
# get prediction from labeler | ||
labeler_predictions = self.data_labeler.predict(self.one_data) | ||
|
||
# for now just checking that it's not empty | ||
# and that len of output is the same as len of | ||
# input values for the model to predict | ||
self.assertIsNotNone(labeler_predictions) | ||
self.assertEqual(len(self.one_data), len(labeler_predictions)) | ||
|
||
def test_results_filtering(self): | ||
"""test where false negative doesn't exist | ||
and true positive is filtered | ||
""" | ||
|
||
self.parameters.pop("false_positive_dict") | ||
model = ColumnNameModel( | ||
label_mapping=self.label_mapping, parameters=self.parameters | ||
) | ||
|
||
labeler_predictions = self.data_labeler.predict(self.two_data) | ||
|
||
self.assertIsNotNone(labeler_predictions) | ||
self.assertEqual(1, len(labeler_predictions)) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"model": {"class": "ColumnNameModel"}, "preprocessor": {"class": "DirectPassPreprocessor"}, "postprocessor": {"class": "ColumnNameModelPostprocessor"}} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these files look good There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dope There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo in here actually I found when |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"ssn": 1, "name": 2, "address": 3} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"true_positive_dict": [{"attribute": "ssn", "label": "ssn"}, {"attribute": "suffix", "label": "name"}, {"attribute": "my_home_address", "label": "address"}], "false_positive_dict": [{"attribute": "contract_number", "label": "ssn"}, {"attribute": "role", "label": "name"}, {"attribute": "send_address", "label": "address"}], "negative_threshold_config": 50, "include_label": true} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"true_positive_dict": [{"attribute": "ssn", "label": "ssn"}, {"attribute": "suffix", "label": "name"}, {"attribute": "my_home_address", "label": "address"}], "positive_threshold_config": 85} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. DirectPassPreprocessor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of these, could call the super?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe we can't bc regex didn't? or an oversight.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm seeing some of the other models use the
BaseModel.__init__
but when I do that or the super unit tests are failing... I'll troubleshoot a bit more though
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not seeing success with doing
BaseModel.__init__(self, label_mapping, parameters)
... maybe I'm doing something wrong, though
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
curious, what were the errors?