diff --git a/medcat/utils/meta_cat/data_utils.py b/medcat/utils/meta_cat/data_utils.py
index 83661748..3d043130 100644
--- a/medcat/utils/meta_cat/data_utils.py
+++ b/medcat/utils/meta_cat/data_utils.py
@@ -180,23 +180,6 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
         category_value2id = {}
     category_values = set([x[2] for x in data])
 
-    # Ensuring that each label has data and checking for class imbalance
-
-    label_data = {key: 0 for key in category_value2id}
-    for i in range(len(data)):
-        if data[i][2] in category_value2id:
-            label_data[data[i][2]] = label_data[data[i][2]] + 1
-
-    # If a label has no data, changing the mapping
-    if 0 in label_data.values():
-        category_value2id_: Dict = {}
-        keys_ls = [key for key, value in category_value2id.items() if value != 0]
-        for k in keys_ls:
-            category_value2id_[k] = len(category_value2id_)
-
-        logger.warning("Labels found with 0 data; updates made\nFinal label encoding mapping: %s",category_value2id_)
-        category_value2id = category_value2id_
-
     for c in category_values:
         if c not in category_value2id:
             category_value2id[c] = len(category_value2id)
diff --git a/medcat/utils/meta_cat/ml_utils.py b/medcat/utils/meta_cat/ml_utils.py
index 0ba068d6..a7acf34c 100644
--- a/medcat/utils/meta_cat/ml_utils.py
+++ b/medcat/utils/meta_cat/ml_utils.py
@@ -66,7 +66,7 @@ class label of the data
 
     x = torch.tensor(x, dtype=torch.long).to(device)
     # cpos = torch.tensor(cpos, dtype=torch.long).to(device)
-    attention_masks = (x != 0).type(torch.int)
+    attention_masks = (x != pad_id).type(torch.int)
 
     return x, cpos, attention_masks, y
 
@@ -412,10 +412,16 @@ def eval_model(model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: T
 
     precision, recall, f1, support = precision_recall_fscore_support(y_eval, predictions, average=score_average)
     labels = [name for (name, _) in sorted(config.general['category_value2id'].items(), key=lambda x: x[1])]
+    labels_present_: set = set(predictions)
+    labels_present: List[str] = [str(x) for x in labels_present_]
+
+    if len(labels) != len(labels_present):
+        logger.warning(
+            "The evaluation dataset does not contain all the labels, some labels are missing. Performance displayed for labels found...")
     confusion = pd.DataFrame(
         data=confusion_matrix(y_eval, predictions, ),
-        columns=["true " + label for label in labels],
-        index=["predicted " + label for label in labels],
+        columns=["true " + label for label in labels_present],
+        index=["predicted " + label for label in labels_present],
     )
 
     examples: Dict = {'FP': {}, 'FN': {}, 'TP': {}}