CU-2tuwdjf Move to an identifier based config (CogStack#257)
* Initial commit with new identifier based config

* Fixed issue with saving into json for idconfig

* Added support for MetaCAT and NER configs with the new identifier based config system, as well as a very short test case for the new config

* New identifier based config now allows for creation of new attributes

* Now allowing extra attributes for each of the sub-types of the config

* Moved the comments within configs to docstrings so IDEs would pick them up and show them as/when needed

* Moved identifier based config to regular config package along with tests from a separate file to the main config test file. Removed all remnants of the old config system

* Removed the requirement that config merging go through dicts

* Removed most getitem config operations

* Removed unnecessary basemodel checks from new config

* Moved some further comments into docstrings on the config level

* Added two new simple tests for legacy (getitem) support as well as the new identifier based method (see the sketch after this list)

* Moved away from Field assignments for config defaults since BaseModel handles new instance initialisation

* Added extra tests to make sure all the config keys are available through both the legacy getitem method as well as through the attribute based method

* Removed unused code from test

* Moved around some things to avoid duplication

* Removed (now) irrelevant comments from config

* Removed unused enums from config

* Added some further docstrings to config class(es)

* Being more accurate regarding setitem return type

* Removed debug from logged output

* Making sure that type validation happens at construction time for configs

* Added an alternative parser so that {} can be read as either an empty set or an empty dictionary when reading from config

* Now validating config values when assigning as well

* Added new tests for merge and assignment validation

* Removed legacy comments

* Fixed a few type hinting issues that caused mypy to result in errors

* Fixed indentation and blank line issues within the MetaCAT and transformers NER configs

* Added tests that make sure config hashes change when values change and are consistent for identical configs

* Added test that ensures version-specific config stuff doesn't affect the hash

* Added pydantic dependency to setup.py

* Moved config logger to class attribute

* Fixed typo in docstring

* Fixed typo in config test

* Removed inheritance from object

* Added tests for config's parse_config_file method
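
As a rough illustration of what the commit message describes: a minimal sketch of an identifier-based config built on pydantic v1 (the dependency this commit adds to setup.py). The class and field names below are illustrative only, not MedCAT's actual schema.

    # Minimal sketch, assuming pydantic v1; names are illustrative.
    from pydantic import BaseModel, Extra, ValidationError


    class GeneralPart(BaseModel):
        """One sub-config; descriptions live in docstrings so IDEs surface them."""
        log_level: int = 20
        spacy_model: str = 'en_core_web_md'

        class Config:
            extra = Extra.allow          # new attributes may be created
            validate_assignment = True   # values validated on assignment too


    class ExampleConfig(BaseModel):
        general: GeneralPart = GeneralPart()

        def __getitem__(self, name):
            # Legacy dict-style access kept for backwards compatibility
            return getattr(self, name)


    config = ExampleConfig()
    config.general.log_level = 10              # identifier-based access
    assert config['general'].log_level == 10   # legacy getitem still works
    config.general.some_new_flag = True        # extra attributes are allowed
    try:
        config.general.log_level = 'not a number'
    except ValidationError:
        print('assignment is validated, as the new tests expect')

Because defaults live on the model itself, dict-style lookups with fallbacks such as config.preprocessing.get('max_document_length', 1000000) collapse into plain attribute access in the diff below.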
mart-r authored Sep 11, 2022
1 parent 8bf8ed7 commit 6fc9b4b
Showing 14 changed files with 794 additions and 395 deletions.
104 changes: 52 additions & 52 deletions medcat/cat.py
@@ -99,9 +99,9 @@ def __init__(self,
        self._addl_ner = addl_ner if isinstance(addl_ner, list) else [addl_ner]
        self._create_pipeline(self.config)

-    def _create_pipeline(self, config):
+    def _create_pipeline(self, config: Config):
        # Set log level
-        self.log.setLevel(config.general['log_level'])
+        self.log.setLevel(config.general.log_level)

        # Build the pipeline
        self.pipe = Pipe(tokenizer=spacy_split_all, config=config)
@@ -123,14 +123,14 @@ def _create_pipeline(self, config):

        # Add addl_ner if they exist
        for ner in self._addl_ner:
-            self.pipe.add_addl_ner(ner, ner.config.general['name'])
+            self.pipe.add_addl_ner(ner, ner.config.general.name)

        # Add meta_annotation classes if they exist
        for meta_cat in self._meta_cats:
-            self.pipe.add_meta_cat(meta_cat, meta_cat.config.general['category_name'])
+            self.pipe.add_meta_cat(meta_cat, meta_cat.config.general.category_name)

        # Set max document length
-        self.pipe.spacy_nlp.max_length = config.preprocessing.get('max_document_length', 1000000)
+        self.pipe.spacy_nlp.max_length = config.preprocessing.max_document_length

    @deprecated(message="Replaced with cat.pipe.spacy_nlp.")
    def get_spacy_nlp(self) -> Language:
@@ -165,17 +165,17 @@ def get_model_card(self, as_dict=False):
            By default a str - indented JSON object.
        """
        card = {
-            'Model ID': self.config.version['id'],
-            'Last Modified On': self.config.version['last_modified'],
-            'History (from least to most recent)': self.config.version['history'],
-            'Description': self.config.version['description'],
-            'Source Ontology': self.config.version['ontology'],
-            'Location': self.config.version['location'],
-            'MetaCAT models': self.config.version['meta_cats'],
-            'Basic CDB Stats': self.config.version['cdb_info'],
-            'Performance': self.config.version['performance'],
+            'Model ID': self.config.version.id,
+            'Last Modified On': self.config.version.last_modified,
+            'History (from least to most recent)': self.config.version.history,
+            'Description': self.config.version.description,
+            'Source Ontology': self.config.version.ontology,
+            'Location': self.config.version.location,
+            'MetaCAT models': self.config.version.meta_cats,
+            'Basic CDB Stats': self.config.version.cdb_info,
+            'Performance': self.config.version.performance,
            'Important Parameters (Partial view, all available in cat.config)': get_important_config_parameters(self.config),
-            'MedCAT Version': self.config.version['medcat_version']
+            'MedCAT Version': self.config.version.medcat_version
        }

        if as_dict:
@@ -185,20 +185,20 @@ def get_model_card(self, as_dict=False):

    def _versioning(self):
        # Check version info and do not allow without it
-        if self.config.version['description'] == 'No description':
+        if self.config.version.description == 'No description':
            self.log.warning("Please consider populating the version information [description, performance, location, ontology] in cat.config.version")

        # Fill the stuff automatically that is needed for versioning
        m = self.get_hash()
        version = self.config.version
-        if version['id'] is None or m != version['id']:
-            if version['id'] is not None:
-                version['history'].append(version['id'])
-            version['id'] = m
-            version['last_modified'] = date.today().strftime("%d %B %Y")
-            version['cdb_info'] = self.cdb._make_stats()
-            version['meta_cats'] = [meta_cat.get_model_card(as_dict=True) for meta_cat in self._meta_cats]
-            version['medcat_version'] = __version__
+        if version.id is None or m != version.id:
+            if version.id is not None:
+                version.history.append(version['id'])
+            version.id = m
+            version.last_modified = date.today().strftime("%d %B %Y")
+            version.cdb_info = self.cdb._make_stats()
+            version.meta_cats = [meta_cat.get_model_card(as_dict=True) for meta_cat in self._meta_cats]
+            version.medcat_version = __version__
            self.log.warning("Please consider updating [description, performance, location, ontology] in cat.config.version")

    def create_model_pack(self, save_dir_path: str, model_pack_name: str = DEFAULT_MODEL_PACK_NAME) -> str:
@@ -211,10 +211,10 @@ def create_model_pack(self, save_dir_path: str, model_pack_name: str = DEFAULT_M
            Model pack name
        """
        # Spacy model always should be just the name, but during loading it can be reset to path
-        self.config.general['spacy_model'] = os.path.basename(self.config.general['spacy_model'])
+        self.config.general.spacy_model = os.path.basename(self.config.general.spacy_model)
        # Versioning
        self._versioning()
-        model_pack_name += "_{}".format(self.config.version['id'])
+        model_pack_name += "_{}".format(self.config.version.id)

        self.log.warning("This will save all models into a zip file, can take some time and require quite a bit of disk space.")
        _save_dir_path = save_dir_path
@@ -224,7 +224,7 @@ def create_model_pack(self, save_dir_path: str, model_pack_name: str = DEFAULT_M
        os.makedirs(os.path.expanduser(save_dir_path), exist_ok=True)

        # Save the used spacy model
-        spacy_path = os.path.join(save_dir_path, self.config.general['spacy_model'])
+        spacy_path = os.path.join(save_dir_path, self.config.general.spacy_model)
        if str(self.pipe.spacy_nlp._path) != spacy_path:
            # First remove if something is there
            shutil.rmtree(spacy_path, ignore_errors=True)
@@ -243,7 +243,7 @@ def create_model_pack(self, save_dir_path: str, model_pack_name: str = DEFAULT_M
        # Save addl_ner
        for comp in self.pipe.spacy_nlp.components:
            if isinstance(comp[1], TransformersNER):
-                trf_path = os.path.join(save_dir_path, "trf_" + comp[1].config.general['name'])
+                trf_path = os.path.join(save_dir_path, "trf_" + comp[1].config.general.name)
                comp[1].save(trf_path)

        # Save all meta_cats
@@ -298,7 +298,7 @@ def load_model_pack(cls, zip_path: str, meta_cat_config_dict: Optional[Dict] = N
        # TODO load addl_ner

        # Modify the config to contain full path to spacy model
-        cdb.config.general['spacy_model'] = os.path.join(model_pack_path, os.path.basename(cdb.config.general['spacy_model']))
+        cdb.config.general.spacy_model = os.path.join(model_pack_path, os.path.basename(cdb.config.general.spacy_model))

        # Load Vocab
        vocab_path = os.path.join(model_pack_path, "vocab.dat")
@@ -343,7 +343,7 @@ def __call__(self, text: Optional[str], do_train: bool = False) -> Optional[Doc]
        """
        # Should we train - do not use this for training, unless you know what you are doing. Use the
        #self.train() function
-        self.config.linking['train'] = do_train
+        self.config.linking.train = do_train

        if text is None:
            self.log.error("The input text should be either a string or a sequence of strings but got %s", type(text))
@@ -423,8 +423,8 @@ def _print_stats(self,
        fp_docs: Set = set()
        fn_docs: Set = set()
        # reset and back up filters
-        _filters = deepcopy(self.config.linking['filters'])
-        filters = self.config.linking['filters']
+        _filters = deepcopy(self.config.linking.filters)
+        filters = self.config.linking.filters
        for pind, project in tqdm(enumerate(data['projects']), desc="Stats project", total=len(data['projects']), leave=False):
            filters['cuis'] = set()

@@ -583,21 +583,21 @@ def _print_stats(self,
                traceback.print_exc()

        # restore filters to original state
-        self.config.linking['filters'] = _filters
+        self.config.linking.filters = _filters

        return fps, fns, tps, cui_prec, cui_rec, cui_f1, cui_counts, examples

    def _init_ckpts(self, is_resumed, checkpoint):
-        if self.config.general['checkpoint']['steps'] is not None or checkpoint is not None:
-            checkpoint_config = CheckpointConfig(**self.config.general.get('checkpoint', {}))
+        if self.config.general.checkpoint.steps is not None or checkpoint is not None:
+            checkpoint_config = CheckpointConfig(**self.config.general.checkpoint.dict())
            checkpoint_manager = CheckpointManager('cat_train', checkpoint_config)
            if is_resumed:
                # TODO: probably remove is_resumed mark and always resume if a checkpoint is provided,
                #but I'll leave it for now
                checkpoint = checkpoint or checkpoint_manager.get_latest_checkpoint()
                self.log.info(f"Resume training on the most recent checkpoint at {checkpoint.dir_path}...")
                self.cdb = checkpoint.restore_latest_cdb()
-                self.cdb.config.merge_config(self.config.__dict__)
+                self.cdb.config.merge_config(self.config.asdict())
                self.config = self.cdb.config
                self._create_pipeline(self.config)
            else:
@@ -657,7 +657,7 @@ def train(self,
            if checkpoint is not None and checkpoint.steps is not None and latest_trained_step % checkpoint.steps == 0:
                checkpoint.save(cdb=self.cdb, count=latest_trained_step)

-        self.config.linking['train'] = False
+        self.config.linking.train = False

    def add_cui_to_group(self, cui: str, group_name: str) -> None:
        r"""
@@ -848,8 +848,8 @@ def train_supervised(self,
        checkpoint = self._init_ckpts(is_resumed, checkpoint)

        # Backup filters
-        _filters = deepcopy(self.config.linking['filters'])
-        filters = self.config.linking['filters']
+        _filters = deepcopy(self.config.linking.filters)
+        filters = self.config.linking.filters

        fp = fn = tp = p = r = f1 = examples = {}
        with open(data_path) as f:
@@ -967,7 +967,7 @@ def train_supervised(self,
                extra_cui_filter=extra_cui_filter)

        # Set the filters again
-        self.config.linking['filters'] = _filters
+        self.config.linking.filters = _filters

        return fp, fn, tp, p, r, f1, cui_counts, examples

@@ -1015,8 +1015,8 @@ def get_entities_multi_texts(self,
            elif out[i].get('text', '') != text:
                out.insert(i, self._doc_to_out(None, only_cui, addl_info))

-        cnf_annotation_output = getattr(self.config, 'annotation_output', {})
-        if not(cnf_annotation_output.get('include_text_in_output', False)):
+        cnf_annotation_output = self.config.annotation_output
+        if not(cnf_annotation_output.include_text_in_output):
            for o in out:
                if o is not None:
                    o.pop('text', None)
@@ -1156,7 +1156,7 @@ def multiprocessing(self,
            raise Exception("Please do not use multiprocessing when running a transformer model for NER, run sequentially.")

        # Set max document length
-        self.pipe.spacy_nlp.max_length = self.config.preprocessing.get('max_document_length', 1000000)
+        self.pipe.spacy_nlp.max_length = self.config.preprocessing.max_document_length

        if self._meta_cats and not separate_nn_components:
            # Hack for torch using multithreading, which is not good if not
@@ -1388,10 +1388,10 @@ def _doc_to_out(self,
                    addl_info: List[str],
                    out_with_text: bool = False) -> Dict:
        out: Dict = {'entities': {}, 'tokens': []}
-        cnf_annotation_output = getattr(self.config, 'annotation_output', {})
+        cnf_annotation_output = self.config.annotation_output
        if doc is not None:
            out_ent: Dict = {}
-            if self.config.general.get('show_nested_entities', False):
+            if self.config.general.show_nested_entities:
                _ents = []
                for _ent in doc._.ents:
                    entity = Span(doc, _ent['start'], _ent['end'], label=_ent['label'])
@@ -1405,18 +1405,18 @@ def _doc_to_out(self,
            else:
                _ents = doc.ents

-            if cnf_annotation_output.get("lowercase_context", True):
+            if cnf_annotation_output.lowercase_context:
                doc_tokens = [tkn.text_with_ws.lower() for tkn in list(doc)]
            else:
                doc_tokens = [tkn.text_with_ws for tkn in list(doc)]

-            if cnf_annotation_output.get('doc_extended_info', False):
+            if cnf_annotation_output.doc_extended_info:
                # Add tokens if extended info
                out['tokens'] = doc_tokens

-            context_left = cnf_annotation_output.get('context_left', -1)
-            context_right = cnf_annotation_output.get('context_right', -1)
-            doc_extended_info = cnf_annotation_output.get('doc_extended_info', False)
+            context_left = cnf_annotation_output.context_left
+            context_right = cnf_annotation_output.context_right
+            doc_extended_info = cnf_annotation_output.doc_extended_info

            for _, ent in enumerate(_ents):
                cui = str(ent._.cui)
@@ -1453,12 +1453,12 @@ def _doc_to_out(self,
                else:
                    out['entities'][ent._.id] = cui

-        if cnf_annotation_output.get('include_text_in_output', False) or out_with_text:
+        if cnf_annotation_output.include_text_in_output or out_with_text:
            out['text'] = doc.text
        return out

    def _get_trimmed_text(self, text: Optional[str]) -> str:
-        return text[0:self.config.preprocessing.get('max_document_length')] if text is not None and len(text) > 0 else ""
+        return text[0:self.config.preprocessing.max_document_length] if text is not None and len(text) > 0 else ""

    def _generate_trimmed_texts(self, texts: Union[Iterable[str], Iterable[Tuple]]) -> Iterable[str]:
        text_: str
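
Two changes in the hunks above share one boundary pattern: CheckpointConfig(**self.config.general.checkpoint.dict()) and merge_config(self.config.asdict()) both flatten a pydantic sub-model back into a plain dict where dict-based code still lives. A hedged sketch of that pattern follows; CheckpointSettings and its fields are made up for illustration and may not match MedCAT's actual CheckpointConfig.

    # Sketch of the "flatten to a dict at the boundary" pattern from
    # _init_ckpts above, assuming pydantic v1. Field names are invented.
    from typing import Optional
    from pydantic import BaseModel


    class CheckpointSettings(BaseModel):
        output_dir: str = 'checkpoints'
        steps: Optional[int] = None
        max_to_keep: int = 1


    class PlainCheckpointConfig:
        """Stand-in for a keyword-initialised, non-pydantic class."""
        def __init__(self, output_dir: str = 'checkpoints',
                     steps: Optional[int] = None, max_to_keep: int = 1):
            self.output_dir = output_dir
            self.steps = steps
            self.max_to_keep = max_to_keep


    settings = CheckpointSettings(steps=1000)
    # .dict() yields a plain dict, so ** unpacking feeds the plain class
    # exactly the way the old dict-based config did:
    ckpt = PlainCheckpointConfig(**settings.dict())
    assert ckpt.steps == 1000
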
15 changes: 8 additions & 7 deletions medcat/cdb.py
@@ -378,7 +378,7 @@ def save(self, path: str) -> None:
        with open(path, 'wb') as f:
            # No idea how to this correctly
            to_save = {}
-            to_save['config'] = self.config.__dict__
+            to_save['config'] = self.config.asdict()
            to_save['cdb'] = {k:v for k,v in self.__dict__.items() if k != 'config'}
            dill.dump(to_save, f)

@@ -664,14 +664,15 @@ def most_similar(self,
    @staticmethod
    def _ensure_backward_compatibility(config: Config) -> None:
        # Hacky way of supporting old CDBs
-        weighted_average_function = config.linking['weighted_average_function']
+        weighted_average_function = config.linking.weighted_average_function
        if callable(weighted_average_function) and getattr(weighted_average_function, "__name__", None) == "<lambda>":
-            config.linking['weighted_average_function'] = partial(weighted_average, factor=0.0004)
-        if config.general.get('workers', None) is None:
-            config.general['workers'] = workers()
-        disabled_comps = config.general.get('spacy_disabled_components', [])
+            # the following type ignoring is for mypy because it is unable to detect the signature
+            config.linking.weighted_average_function = partial(weighted_average, factor=0.0004)  # type: ignore
+        if config.general.workers is None:
+            config.general.workers = workers()
+        disabled_comps = config.general.spacy_disabled_components
        if 'tagger' in disabled_comps and 'lemmatizer' not in disabled_comps:
-            config.general['spacy_disabled_components'].append('lemmatizer')
+            config.general.spacy_disabled_components.append('lemmatizer')

    @classmethod
    def _check_medcat_version(cls, config_data: Dict) -> None:
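
The _ensure_backward_compatibility hunk above swaps a stored lambda for functools.partial. A plausible reason, inferred rather than stated in the commit: plain pickle refuses lambdas outright and dill has to serialize their byte code, while a partial of a named module-level function (de)serializes by reference. A small sketch with a stand-in weighted_average (the formula here is invented):

    # Sketch: why a partial survives pickling where a lambda does not.
    # weighted_average is a stand-in, not MedCAT's actual implementation.
    import pickle
    from functools import partial


    def weighted_average(step: int, factor: float) -> float:
        return max(0.1, 1 - step ** 2 * factor)


    as_partial = partial(weighted_average, factor=0.0004)
    restored = pickle.loads(pickle.dumps(as_partial))  # round-trips by reference
    assert restored(step=10) == as_partial(step=10)

    as_lambda = lambda step: weighted_average(step, factor=0.0004)
    try:
        pickle.dumps(as_lambda)
    except pickle.PicklingError:
        print('lambdas cannot be pickled by reference')
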
(Diffs for the remaining 12 changed files are not shown.)
