Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CU2e77a5x - Add a CDB merge function #373

Merged
merged 18 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions medcat/utils/cdb_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import logging
import numpy as np

from copy import deepcopy
from medcat.cdb import CDB

logger = logging.getLogger(__name__) # separate logger from the package-level one


class cdb_utils(object):
adam-sutton-1992 marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def merge_cdb(cdb1: "CDB",
cdb2: "CDB",
overwrite_training: int = 0,
full_build: bool = False):
"""Merge two CDB's together to produce a new, single CDB. The contents of inputs CDBs will not be changed.
`addl_info` can not be perfectly merged, and will prioritise cdb1. see `full_build`

Args:
cdb1 (medcat.cdb.CDB):
The first medcat cdb to merge. In cases where merging isn't suitable isn't ideal (such as
cui2preferred_name), this cdb values will be prioritised over cdb2.
cdb2 (medcat.cdb.CDB):
The second medcat cdb to merge.
overwrite_training (int):
Choose to prioritise a CDB's context vectors values over merging gracefully. 0 - no prio, 1 - CDB1, 2 - CDB2
full_build (bool):
Add additional information from "addl_info" dicts "cui2ontologies" and "cui2description"
"""
config = deepcopy(cdb1.config)
cdb = CDB(config)

# Copy CDB 1 - as all settings from CDB 1 will be carried over
cdb.cui2names = deepcopy(cdb1.cui2names)
cdb.cui2snames = deepcopy(cdb1.cui2snames)
cdb.cui2count_train = deepcopy(cdb1.cui2count_train)
cdb.cui2info = deepcopy(cdb1.cui2info)
cdb.cui2context_vectors = deepcopy(cdb1.cui2context_vectors)
cdb.cui2tags = deepcopy(cdb1.cui2tags)
cdb.cui2type_ids = deepcopy(cdb1.cui2type_ids)
cdb.cui2preferred_name = deepcopy(cdb1.cui2preferred_name)
cdb.name2cuis = deepcopy(cdb1.name2cuis)
cdb.name2cuis2status = deepcopy(cdb1.name2cuis2status)
cdb.name2count_train = deepcopy(cdb1.name2count_train)
cdb.name_isupper = deepcopy(cdb1.name_isupper)
if full_build:
cdb.addl_info = deepcopy(cdb1.addl_info)

# handles cui2names, cui2snames, name_isupper, name2cuis, name2cuis2status, cui2preferred_name
for cui in cdb2.cui2names:
names = dict()
for name in cdb2.cui2names[cui]:
names[name] = {'snames': cdb2.cui2snames.get(cui, set()), 'is_upper': cdb2.name_isupper.get(name, False), 'tokens': {}, 'raw_name': cdb2.get_name(cui)}
name_status = cdb2.name2cuis2status.get(name, 'A').get(cui, 'A') # get the name status if it exists, default to 'A'
# For addl_info check cui2original_names as they MUST be added
ontologies = set()
description = ''
to_build = False
if full_build and (cui in cdb2.addl_info['cui2original_names'] or cui in cdb2.addl_info['cui2description']):
to_build = True
if 'cui2ontologies' in cdb2.addl_info:
ontologies.update(cdb2.addl_info['cui2ontologies'][cui])
if 'cui2description' in cdb2.addl_info:
description = cdb2.addl_info['cui2description'][cui]
cdb.add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status,
type_ids=cdb2.cui2type_ids[cui], description=description, full_build=to_build)
if cui in cdb1.cui2names:
if (cui in cdb1.cui2count_train or cui in cdb2.cui2count_train) and not (overwrite_training == 1 and cui in cdb1.cui2count_train):
if overwrite_training == 2 and cui in cdb2.cui2count_train:
cdb.cui2count_train[cui] = cdb2.cui2count_train[cui]
else:
cdb.cui2count_train[cui] = cdb1.cui2count_train.get(cui, 0) + cdb2.cui2count_train.get(cui, 0)
if cui in cdb1.cui2context_vectors and not (overwrite_training == 1 and cui in cdb1.cui2context_vectors[cui]):
if overwrite_training == 2 and cui in cdb2.cui2context_vectors:
weights = [0, 1]
else:
norm = cdb.cui2count_train[cui]
weights = [np.divide(cdb1.cui2count_train.get(cui, 0), norm), np.divide(cdb2.cui2count_train.get(cui, 0), norm)]
contexts = set(list(cdb1.cui2context_vectors.get(cui, {}).keys()) + list(cdb2.cui2context_vectors.get(cui, {}).keys())) # xlong, long, medium, short
for s in contexts:
cdb.cui2context_vectors[cui][s] = (weights[0] * cdb1.cui2context_vectors[cui].get(s, np.zeros(shape=(300)))) + (weights[1] * cdb2.cui2context_vectors[cui].get(s, np.zeros(shape=(300))))
if cui in cdb1.cui2tags:
cdb.cui2tags[cui].append(cdb2.cui2tags[cui])
if cui in cdb1.cui2type_ids:
cdb.cui2type_ids[cui] = cdb1.cui2type_ids[cui].union(cdb2.cui2type_ids[cui])
else:
if cui in cdb2.cui2count_train:
cdb.cui2count_train[cui] = cdb2.cui2names[cui]
if cui in cdb2.cui2info:
cdb.cui2info[cui] = cdb2.cui2info[cui]
if cui in cdb2.cui2context_vectors:
cdb.cui2context_vectors[cui] = cdb2.cui2context_vectors[cui]
if cui in cdb2.cui2tags:
cdb.cui2tags[cui] = cdb2.cui2tags[cui]
if cui in cdb2.cui2type_ids:
cdb.cui2type_ids[cui] = cdb2.cui2type_ids[cui]

if overwrite_training != 1:
for name in cdb2.name2cuis:
if name in cdb1.name2cuis and overwrite_training == 0: # if they exist in both cdbs
if name in cdb1.name2count_train and name in cdb2.name2count_train:
cdb.name2count_train[name] = str(int(cdb1.name2count_train[name]) + int(cdb2.name2count_train[name])) # these are strings for some reason
else:
if name in cdb2.name2count_train:
cdb.name2count_train[name] = cdb2.name2count_train[name]

# snames
cdb.snames = cdb1.snames.union(cdb2.snames)

# vocab, adding counts if they occur in both
cdb.vocab = deepcopy(cdb1.vocab)
if overwrite_training != 1:
for word in cdb2.vocab:
if word in cdb.vocab and overwrite_training == 0:
cdb.vocab[word] += cdb2.vocab[word]
else:
cdb.vocab[word] = cdb2.vocab[word]

return cdb
35 changes: 35 additions & 0 deletions tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np

from medcat.vocab import Vocab
from medcat.cdb_maker import CDBMaker
from medcat.config import Config


class AsyncMock(unittest.mock.MagicMock):
Expand Down Expand Up @@ -86,3 +88,36 @@ def check_or_download(self):
return
with open(self.vocab_path, 'wb') as f:
f.write(tmp.content)


class ForCDBMerging:

def __init__(self) -> None:
# generating cdbs - two maker are requested as they point to the same created CDB.
config = Config()
config.general["spacy_model"] = "en_core_web_md"
maker1 = CDBMaker(config)
maker2 = CDBMaker(config) # second maker is required as it will otherwise point to same object
path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "model_creator", "umls_sample.csv")
self.cdb1 = maker1.prepare_csvs(csv_paths=[path])
self.cdb2 = maker2.prepare_csvs(csv_paths=[path])

# generating context vectors here for for testing the weighted average function (based off cui2count_train)
zeroes = np.zeros(shape=(1,300))
ones = np.ones(shape=(1,300))
for i, cui in enumerate(self.cdb1.cui2names):
self.cdb1.cui2context_vectors[cui] = {"short": ones}
self.cdb2.cui2context_vectors[cui] = {"short": zeroes}
self.cdb1.cui2count_train[cui] = 1
self.cdb2.cui2count_train[cui] = i + 1
# adding new names and cuis to each cdb to test after merging
test_add = {"test": {'tokens': "test_token", 'snames': ["test_name"], 'raw_name': "test_raw_name", "is_upper": "P"}}
self.cdb1.add_names("C0006826", test_add)
unique_test = {"test": {'tokens': "test_token", 'snames': ["test_name"], 'raw_name': "test_raw_name", "is_upper": "P"}}
self.cdb2.add_names("UniqueTest", unique_test)
self.cdb2.cui2context_vectors["UniqueTest"] = {"short": zeroes}
self.cdb2.addl_info["cui2ontologies"] = {}
self.cdb2.addl_info["cui2description"] = {}
for cui in self.cdb2.cui2names:
self.cdb2.addl_info["cui2ontologies"][cui] = {"test_ontology"}
self.cdb2.addl_info["cui2description"][cui] = "test_description"
1 change: 1 addition & 0 deletions tests/test_cdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,5 +90,6 @@ def test_cui2snames_population(self):
with self.subTest(cui):
self.assertIn(cui, self.undertest.cui2snames)


adam-sutton-1992 marked this conversation as resolved.
Show resolved Hide resolved
if __name__ == '__main__':
unittest.main()
42 changes: 42 additions & 0 deletions tests/utils/test_cdb_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import unittest
import numpy as np
from tests.helper import ForCDBMerging
from medcat.utils.cdb_utils import cdb_utils


class CDBMergeTests(unittest.TestCase):
@classmethod
def setUp(cls) -> None:
adam-sutton-1992 marked this conversation as resolved.
Show resolved Hide resolved
to_merge = ForCDBMerging()
cls.cdb1 = to_merge.cdb1
cls.cdb2 = to_merge.cdb2
cls.merged_cdb = cdb_utils.merge_cdb(cdb1=cls.cdb1, cdb2=cls.cdb2)
cls.overwrite_cdb = cdb_utils.merge_cdb(cdb1=cls.cdb1, cdb2=cls.cdb2, overwrite_training=2, full_build=True)
cls.zeroes = np.zeros(shape=(1,300))
cls.ones = np.ones(shape=(1,300))

def test_merge_inserts(self):
self.assertIn("test", self.merged_cdb.cui2names["C0006826"])
self.assertIn("test_name", self.merged_cdb.cui2snames["C0006826"])
self.assertEqual("Cancer", self.merged_cdb.cui2preferred_name["C0006826"])

def test_no_full_build(self):
self.assertEqual(self.merged_cdb.addl_info["cui2ontologies"], dict())
self.assertEqual(self.merged_cdb.addl_info["cui2ontologies"], dict())

def test_full_build(self):
for cui in self.cdb2.cui2names:
self.assertEqual(self.overwrite_cdb.addl_info["cui2ontologies"][cui], {"test_ontology"})
self.assertEqual(self.overwrite_cdb.addl_info["cui2description"][cui], "test_description")

def test_vector_merge(self):
self.assertTrue(np.array_equal(self.zeroes, self.merged_cdb.cui2context_vectors["UniqueTest"]["short"]))
for i, cui in enumerate(self.cdb1.cui2names):
self.assertTrue(np.array_equal(self.merged_cdb.cui2context_vectors[cui]["short"], np.divide(self.ones, i+2)))


def test_overwrite_parameter(self):
for cui in self.cdb2.cui2names:
self.assertTrue(np.array_equal(self.overwrite_cdb.cui2context_vectors[cui]["short"], self.zeroes))
self.assertEqual(self.overwrite_cdb.addl_info["cui2ontologies"][cui], {"test_ontology"})
self.assertEqual(self.overwrite_cdb.addl_info["cui2description"][cui], "test_description")