-
Notifications
You must be signed in to change notification settings - Fork 169
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ColumnNameLabeler Setup #635
Changes from all commits
5f90417
a924f28
8cd9074
e18bb20
5a3209a
dcc1052
f241707
6fc59e2
1d46359
035b1f1
c31b49f
79c2ab0
5565c41
c1f5ca1
ce1dc35
e77ca9e
8a0a686
c85635e
a2d1586
a321037
9db41c7
78c29fc
d63398a
1c442aa
3cdc143
0d7a49e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ | |
class ColumnNameModel(BaseModel, metaclass=AutoSubRegistrationMeta): | ||
"""Class for column name data labeling model.""" | ||
|
||
def __init__(self, parameters=None): | ||
def __init__(self, label_mapping=None, parameters=None): | ||
"""Initialize function for ColumnNameModel. | ||
|
||
:param parameters: Contains all the appropriate parameters for the model. | ||
|
@@ -38,8 +38,10 @@ def __init__(self, parameters=None): | |
parameters.setdefault("true_positive_dict", None) | ||
parameters.setdefault("include_label", True) | ||
parameters.setdefault("negative_threshold_config", None) | ||
parameters.setdefault("positive_threshold_config", None) | ||
|
||
# initialize class | ||
# validate and set parameters | ||
self.set_label_mapping(label_mapping) | ||
self._validate_parameters(parameters) | ||
self._parameters = parameters | ||
|
||
|
@@ -65,10 +67,25 @@ def _validate_parameters(self, parameters): | |
"false_positive_dict", | ||
"include_label", | ||
"negative_threshold_config", | ||
"positive_threshold_config", | ||
] | ||
|
||
list_of_accepted_parameters = optional_parameters + required_parameters | ||
|
||
if parameters["true_positive_dict"]: | ||
label_map_dict_keys = set(self.label_mapping.keys()) | ||
true_positive_unique_labels = set( | ||
parameters["true_positive_dict"][0].values() | ||
) | ||
|
||
# if not a subset that is less than or equal to | ||
# label mapping dict | ||
if true_positive_unique_labels > label_map_dict_keys: | ||
errors.append( | ||
"""`true_positive_dict` must be a subset | ||
of the `label_mapping` values()""" | ||
) | ||
|
||
Comment on lines
+75
to
+88
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. checking that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ideally we would say what is not supposed to be there, but that can always be done later. |
||
for param in parameters: | ||
value = parameters[param] | ||
if ( | ||
|
@@ -109,8 +126,15 @@ def _validate_parameters(self, parameters): | |
param | ||
) | ||
) | ||
elif param == "positive_threshold_config" and ( | ||
value is None or not isinstance(value, int) | ||
): | ||
errors.append( | ||
"`{}` is a required parameter that must be a boolean.".format(param) | ||
) | ||
Comment on lines
+129
to
+134
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. moving this back to |
||
elif param not in list_of_accepted_parameters: | ||
errors.append("`{}` is not an accepted parameter.".format(param)) | ||
|
||
if errors: | ||
raise ValueError("\n".join(errors)) | ||
|
||
|
@@ -217,17 +241,30 @@ def predict( | |
include_label=self._parameters["include_label"], | ||
) | ||
|
||
if show_confidences: | ||
raise NotImplementedError( | ||
"""`show_confidences` parameter is disabled | ||
for Proof of Concept implementation. Confidence | ||
values are enabled by default.""" | ||
) | ||
predictions = np.array([]) | ||
confidences = np.array([]) | ||
|
||
# `data` at this point is either filtered or not filtered | ||
# list of column names on which we are predicting | ||
for iter_value, value in enumerate(data): | ||
|
||
if output[iter_value][0] > self._parameters["positive_threshold_config"]: | ||
predictions = np.append( | ||
predictions, | ||
self._parameters["true_positive_dict"][output[iter_value][1]][ | ||
"label" | ||
], | ||
) | ||
|
||
if show_confidences: | ||
confidences = np.append(confidences, output[iter_value][0]) | ||
Comment on lines
+244
to
+260
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. returning in the same way as other models |
||
|
||
if verbose: | ||
logger.info("compare_positive process complete") | ||
|
||
return output | ||
if show_confidences: | ||
return {"pred": predictions, "conf": confidences} | ||
return {"pred": predictions} | ||
Comment on lines
+265
to
+267
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. implementing same output as other data labelers / models |
||
|
||
@classmethod | ||
def load_from_disk(cls, dirpath): | ||
|
@@ -243,7 +280,12 @@ def load_from_disk(cls, dirpath): | |
with open(model_param_dirpath, "r") as fp: | ||
parameters = json.load(fp) | ||
|
||
loaded_model = cls(parameters) | ||
# load label_mapping | ||
labels_dirpath = os.path.join(dirpath, "label_mapping.json") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice |
||
with open(labels_dirpath, "r") as fp: | ||
label_mapping = json.load(fp) | ||
|
||
loaded_model = cls(label_mapping, parameters) | ||
return loaded_model | ||
|
||
def save_to_disk(self, dirpath): | ||
|
@@ -260,3 +302,7 @@ def save_to_disk(self, dirpath): | |
model_param_dirpath = os.path.join(dirpath, "model_parameters.json") | ||
with open(model_param_dirpath, "w") as fp: | ||
json.dump(self._parameters, fp) | ||
|
||
labels_dirpath = os.path.join(dirpath, "label_mapping.json") | ||
with open(labels_dirpath, "w") as fp: | ||
json.dump(self.label_mapping, fp) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2086,15 +2086,9 @@ class ColumnNameModelPostprocessor( | |
): | ||
"""Subclass of BaseDataPostprocessor for postprocessing regex data.""" | ||
|
||
def __init__(self, true_positive_dict=None, positive_threshold_config=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. removing this from the post processor to the model file to allow for NOTE: this basically becomes a |
||
def __init__(self): | ||
"""Initialize the ColumnNameModelPostProcessor class.""" | ||
if true_positive_dict is None: | ||
true_positive_dict = {} | ||
|
||
super().__init__( | ||
true_positive_dict=true_positive_dict, | ||
positive_threshold_config=positive_threshold_config, | ||
) | ||
super().__init__() | ||
|
||
def _validate_parameters(self, parameters): | ||
""" | ||
|
@@ -2118,31 +2112,9 @@ def _validate_parameters(self, parameters): | |
errors = [] | ||
|
||
for param in parameters: | ||
value = parameters[param] | ||
|
||
if param == "true_positive_dict" and ( | ||
not isinstance(value, list) | ||
or not isinstance(value[0], dict) | ||
or "attribute" not in value[0].keys() | ||
or "label" not in value[0].keys() | ||
): | ||
errors.append( | ||
"""`{}` is a required parameter that must be a list | ||
of dictionaries each with the following | ||
two keys: 'attribute' and 'label'""".format( | ||
param | ||
) | ||
) | ||
elif param == "positive_threshold_config" and ( | ||
not isinstance(value, int) or value is None | ||
): | ||
errors.append( | ||
"`{}` is an required parameter that must be an integer.".format( | ||
param | ||
) | ||
) | ||
elif param not in allowed_parameters: | ||
if param not in allowed_parameters: | ||
errors.append("`{}` is not a permited parameter.".format(param)) | ||
|
||
if errors: | ||
raise ValueError("\n".join(errors)) | ||
|
||
|
@@ -2173,16 +2145,4 @@ def help(cls): | |
|
||
def process(self, data, labels=None, label_mapping=None, batch_size=None): | ||
"""Preprocess data.""" | ||
results = {} | ||
for iter_value, value in enumerate(data): | ||
if data[iter_value][0] > self._parameters["positive_threshold_config"]: | ||
results[iter_value] = {} | ||
try: | ||
results[iter_value]["pred"] = self._parameters[ | ||
"true_positive_dict" | ||
][data[iter_value][1]]["label"] | ||
except IndexError: | ||
pass | ||
results[iter_value]["conf"] = data[iter_value][0] | ||
|
||
return results | ||
return labels |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of these, could call the super?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe we can't bc regex didn't? or an oversight.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm seeing some of the other models use the
BaseModel.__init__
but when I do that or the super unit tests are failing... I'll trouble shoot a bit more thoughThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not seeing success with doing
BaseModel.__init__(self, label_mapping, parameters)
... maybe I'm doing something wrong, thoughThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
curious, what were the errors?