From ffd22a620281b649653e7e481c51e662b495c563 Mon Sep 17 00:00:00 2001
From: Mart Ratas
Date: Wed, 19 Jun 2024 17:01:16 +0100
Subject: [PATCH] v1.12.0 release PR (#455)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Pushing changes for bert-style models for MetaCAT
* Pushing fix for LSTM
* Pushing changes for flake8 and type fixes
* Pushing type fixes
* Fixing type issue
* Pushing changes
  1) Added model.zero_grad to clear accumulated gradients
  2) Fixed config save issue
  3) Re-structured data preparation for oversampled data
* Pushing change and type fixes
  Pushing ml_utils file which was missed in the last commit
* Fixing flake8 issues
* Pushing flake8 fixes
* Pushing fixes for flake8
* Pushing flake8 fix
* Adding peft to list of libraries
* Pushing changes with load and train workflow and type fixes
  The workflow for inference is: load() and inference.
  For training: init() and train().
  Train will never load the model dict, except when phase_number is set to 2
  for the second phase of two-phase learning (see the sketch at the end of
  this message).
* Pushing changes with type hints and new documentation
* Pushing type fix
* Fixing type issue
* Adding test case for BERT and reverting config changes
  BERT test cases: testing for the BERT model along with two-phase learning
* Merging changes from master to metacat_bert branch (#431)
* Small addition to contribution guidelines (#420)
* CU-8694cbcpu: Allow specifying an AU Snomed when preprocessing (#421)
* CU-8694dpy1c: Return empty generator upon empty stream (#423)
* CU-8694dpy1c: Return empty generator upon empty stream
* CU-8694dpy1c: Fix empty generator returns
* CU-8694dpy1c: Simplify empty generator returns
* Relation extraction (#173)
* Added files.
* More additions to rel extraction.
* Rel base.
* Update.
* Updates.
* Dependency parsing.
* Updates.
* Added pre-training steps.
* Added training & model utils.
* Cleanup & fixes.
* Update.
* Evaluation updates for pretraining.
* Removed duplicate relation storage.
* Moved RE model file location.
* Structure revisions.
* Added custom config for RE.
* Implemented custom dataset loader for RE.
* More changes.
* Small fix.
* Latest additions to RelCAT (pipe + predictions)
* Setup.py fix.
* RE utils update.
* rel model update.
* rel dataset + tokenizer improvements.
* RelCAT updates.
* RelCAT saving/loading improvements.
* RelCAT saving/loading improvements.
* RelCAT model fixes.
* Attempted gpu learning fix. Dataset label generation fixes.
* Minor train dataset gen fix.
* Minor train dataset gen fix No.2.
* Config updates.
* Gpu support fixes. Added label stats.
* Evaluation stat fixes.
* Cleaned stat output mode during training.
* Build fix.
* removed unused dependencies and fixed code formatting
* Mypy compliance.
* Fixed linting.
* More Gpu mode train fixes.
* Fixed model saving/loading issues when using other base models.
* More fixes to stat evaluation. Added proper CAT integration of RelCAT.
* Setup.py typo fix.
* RelCAT loading fix.
* RelCAT Config changes.
* Type fix. Minor additions to RelCAT model.
* Type fixes.
* Type corrections.
* RelCAT update.
* Type fixes.
* Fixed type issue.
* RelCATConfig: added seed param.
* Adaptations to the new codebase + type fixes.
* Doc/type fixes.
* Fixed input size issue for model.
* Fixed issue(s) with model size and config.
* RelCAT: updated configs to new style.
* RelCAT: removed old refs to logging.
* Fixed GPU training + added extra stat print for train set.
* Type fixes.
* Updated dev requirements.
* Linting.
* Fixed pin_memory issue when training on CPU.
* Updated RelCAT dataset get + default config.
* Updated RelDS generator + default config
* Linting.
* Updated RelDataset + config.
* Pushing updates to model
  Made changes to:
  1) Extracting a given number of context tokens left and right of the entities
  2) Extracting the hidden state from BERT for all the tokens of the entities and performing max pooling on them
* Fixing formatting
* Update rel_dataset.py
* Update rel_dataset.py
* Update rel_dataset.py
* RelCAT: added test resource files.
* RelCAT: Fixed model load/checkpointing.
* RelCAT: updated to pipe spacy doc call.
* RelCAT: added tests.
* Fixed lint/type issues & added rel tag to test DS.
* Fixed ann id to token issue.
* RelCAT: updated test dataset + tests.
* RelCAT: updates to requested changes + dataset improvements.
* RelCAT: updated docs/logs according to comments.
* RelCAT: type fix.
* RelCAT: mct export dataset updates.
* RelCAT: test updates + requested changes p2.
* RelCAT: log for MCT export train.
* Updated docs + split train_test & dataset for benchmarks.
* type fixes.
---------
Co-authored-by: Shubham Agarwal <66172189+shubham-s-agarwal@users.noreply.github.com>
Co-authored-by: mart-r
* CU-8694fae3r: Avoid publishing PyPI release when doing GH pre-releases (#424)
* CU-8694fae3r: Avoid publishing PyPI release when doing GH pre-releases
* CU-8694fae3r: Fix pre-releases tagging
* CU-8694fae3r: Allow actions to run on release edit
---------
Co-authored-by: Mart Ratas
Co-authored-by: Vlad Dinu <62345326+vladd-bit@users.noreply.github.com>
* Pushing changed tests and removing empty change
* Pushing change for logging
* Revert "Pushing change for logging"
  This reverts commit fbcdb704dddda2a36626c4f54ce0672ecfcf6321.
* CU-8694hukwm: Document the materialising of generator when multiproce… (#433)
* CU-8694hukwm: Document the materialising of generator when multiprocessing and batching for docs
* CU-8694hukwm: Add TODO note for where the generator is materialised
* CU-8694hukwm: Add warning for when a large amount of generator data (10k items) is materialised by the docs-size mp method
* CU-8694fk90t (almost) only primitive config (#425)
* CU-8694fk90r: Move backwards compatibility method from CDB to config utils
* CU-8694fk90r: Move weighted_average_function from config to CDB; create necessary backwards compatibility workarounds
* CU-8694fk90r: Move usage of weighted_average_function in tests
* CU-8694fk90r: Add JSON encoder and decoder for re.Pattern
* CU-8694fk90r: Rebuild custom decoder if needed
* CU-8694fk90r: Add method to detect old style config
* CU-8694fk90r: Use regular json serialisation for config; Retain option to read old jsonpickled config
* CU-8694fk90r: Add test for config serialisation
* CU-8694fk90r: Make sure to fix weighted_average_function upon setting it
* CU-8694fk90t: Add missing tests for config utils
* CU-8694fk90t: Add tests for better raised exception upon old way of using weighted_average_function
* CU-8694fk90t: Fix exception type in an added test
* CU-8694fk90t: Add further tests for exception payload
* CU-8694fk90t: Add improved exceptions when using old/unsupported value of weighted_average_function in config
* CU-8694fk90t: Add typing fixes for exceptions
* CU-8694fk90t: Make custom exception derive from AttributeError to correctly handle hasattr calls
* CU-8694gza88: Create codeql.yml (#434)
  Run CodeQL to identify vulnerabilities. This will run on any push or pull
  request to `master`, but also runs once every day in case some new
  vulnerabilities are discovered (or something else changes).
* CU-8694mbn03: Remove the web app (#441)
* CU-8694n48uw better deprecation (#443)
* CU-8694n493m: Add deprecation and removal versions to deprecation decorator
* CU-8694n493m: Add deprecation versions to existing deprecated methods.
  Made the removal version 2 minor versions from the minor version in which the
  method was deprecated, or the next minor version if the method had been
  deprecated for longer.
* CU-8694n4ff0: Raise exception upon deprecated method call at test time
* CU-8694n4ff0: Fix usage of deprecated method calls during test time
* CU-8694pey4u: extract cdb load to cls method, to be used in trainer for model pack loading
* CU-8694pey4u: extract meta cat loading also to a cls method
* CU-8694pey4u: docstrings
* CU-8694pey4u: typehints and mypy issues
* CU-8694pey4u: fix flake8
* CU-8694pey4u: fix flake8
* CU-8694pey4u: missing extra config if passed in
* CU-8694py1jr: Fix issue with reuse of opened file when loading old configs
* CU-8694py1jr: Make old config identifier more robust
* CU-8694py1jr: Add doc string to old config identifier
* CU-8694py1jr: Add test for old style MetaCAT config load
* CU-8694py1jr: Add test for old style main config load (functional)
* CU-8694py1jr: Refactor config utils load tests for more flexibility
* CU-8694py1jr: Add config utils load tests for NER and Rel CAT configs
* CU-8694vcvz7: Trust remote code when loading transformers NER dataset (#453)
* CU-8694vcvz7: Trust remote code when loading transformers NER dataset
* CU-8694vcvz7: Add support for older datasets without the remote code trusting kwarg
* CU-8694gzbn3 k fold metrics (#432)
* CU-8694gzbud: Add context manager that is able to snapshot CDB state
* CU-8694gzbud: Add tests for snapshotting CDB state
* CU-8694gzbud: Refactor tests for CDB state snapshotting
* CU-8694gzbud: Remove use of deprecated method in CDB utils and use non-deprecated one instead
* CU-8694gzbud: Add tests for training and CDB state capturing
* CU-8694gzbud: Small refactor in tests
* CU-8694gzbud: Add option to save state on disk
* CU-8694gzbud: Add debug logging output when saving state on disk
* CU-8694gzbud: Remove unused import
* CU-8694gzbud: Add tests for disk-based state save
* CU-8694gzbud: Move CDB state code to its own module
* CU-8694gzbud: Remove unused import
* CU-8694gzbud: Add doc strings to methods
* CU-8694gzbx4: Small optimisation for stats
* CU-8694gzbx4: Add MCTExport related module
* CU-8694gzbx4: Add MCTExport related tests
* CU-8694gzbx4: Add code for k-fold statistics
* CU-8694gzbx4: Add tests for k-fold statistics
* CU-8694gzbx4: Add test-MCT export with fake concepts
* CU-8694gzbx4: Fix a doc string
* CU-8694gzbx4: Fix types in MCT export module
* CU-8694gzbx4: Fix types in k-fold module
* CU-8694gzbx4: Remove accidentally committed test class
* CU-8694gzbn3: Add missing test helper file
* CU-8694gzbn3: Remove whitespace change from otherwise unchanged file
* CU-8694gzbn3: Allow 5 minutes longer for tests
* CU-8694gzbn3: Move to python 3.8-compatible typed dict
* CU-8694gzbn3: Add more time for tests in workflow (now 30 minutes)
* CU-8694gzbn3: Add more time for tests in workflow (now 45 minutes)
* CU-8694gzbn3: Update test-pypi timeout to 45 minutes
* CU-8694gzbn3: Remove timeout from unit tests in main workflow
* CU-8694gzbn3: Make tests stop upon first failure
* CU-8694gzbn3: Fix test stop upon first failure (arg/option order)
* CU-8694gzbn3: Remove debug code and old comments
* CU-8694gzbn3: Remove all timeouts from main workflow
* CU-8694gzbn3: Remove more old / useless comments in tests
* CU-8694gzbn3: Add debug output when running k-fold tests to see where it may be stalling
* CU-8694gzbn3: Add debug output when ANY tests to see where it may be stalling
* CU-8694gzbn3: Remove explicit debug output from k-fold test cases
* CU-8694gzbn3: Remove timeouts from DEID tests in case they're the ones creating issues
* GHA/test fixes (#437)
* Revert "CU-8694gzbn3: Remove timeouts from DEID tests in case they're the ones creating issues"
  This reverts commit faaf7fb1c4b8b1a9c81ac6b81a464fdd2b55afdc.
* Revert "CU-8694gzbn3: Remove explicit debug output from k-fold test cases"
  This reverts commit 9b0292517c3f442a57fdc04593f721309f0502f7.
* Revert "CU-8694gzbn3: Add debug output when ANY tests to see where it may be stalling"
  This reverts commit 12c519aceb25b960f4b9ff3d3298d18aff8030b4.
* Revert "CU-8694gzbn3: Add debug output when running k-fold tests to see where it may be stalling"
  This reverts commit 03531da0288eb3949807f6720c169625c4d3097c.
* Revert "CU-8694gzbn3: Remove all timeouts from main workflow"
  This reverts commit e6debce71e053ac825f6c8191e5d0e454a77bb11.
* Revert "CU-8694gzbn3: Fix test stop upon first failure (arg/option order)"
  This reverts commit 666c0139f48b7c1565bd250bf0b780069005623f.
* Revert "CU-8694gzbn3: Make tests stop upon first failure"
  This reverts commit 94bce5650d967b0273e1546326108e50f611f687.
* Revert "CU-8694gzbn3: Remove timeout from unit tests in main workflow"
  This reverts commit 3618b9c7cc5b755f430debf129727fab235d79ce.
* CU-8694gzbn3: Improve state copy code in CDB state tests
* CU-8694gzbn3: Fix a CDB state test issue
* CU-8694gzbn3: Split all tests into 2 halves
* CU-8694gzbn3: Remove legacy / archived / unused tests
* CU-8694gzbn3: Add doc strings for FoldCreator init
* CU-8694gzbn3: Move to a split-type enum
* CU-8694gzbn3: Add documentation to split-type enum
* CU-8694gzbn3: Create separate fold creators for different types of splitting strategies
* CU-8694gzbn3: Resort document order in test time nullification process
* CU-8694gzbn3: Add option to count number of annotations in doc for MCT export
* CU-8694gzbn3: Add weighted documents based split option along with relevant tests
* CU-8694gzbn3: Update default fold creation split type to weighted documents
* CU-8694gzbn3: Add test to ensure weighted documents split creates a reasonable number of annotations per split
* CU-8693n892x environment/dependency snapshots (#438)
* CU-8693n892x: Save environment/dependency snapshot upon model pack creation
* CU-8693n892x: Fix typing for env snapshot module
* CU-8693n892x: Add test for env file existence in .zip
* CU-8693n892x: Add doc strings
* CU-8693n892x: Centralise env snapshot file name
* CU-8693n892x: Add env snapshot file to exceptions in serialisation tests
* CU-8693n892x: Only list direct dependencies
* CU-8693n892x: Add test that verifies all direct dependencies are listed in environment
* CU-8693n892x: Move requirements to separate file and use that for environment snapshot
* CU-8693n892x: Remove unused constants
* CU-8693n892x: Allow URL based dependencies when using direct dependencies
* CU-8693n892x: Distribute install_requires.txt alongside the package; use correct path in distributed version
* CU-8694p8y0k deprecation GHA check (#445)
* CU-8694p8y0k: Add check for deprecations (code)
* CU-8694p8y0k: Add workflow check for deprecations
* CU-8694p8y0k: Fix (hopefully) workflow check for deprecations
* CU-8694p8y0k: Add option to remove version prefix when checking deprecation
* CU-8694p8y0k: Update deprecation checks with more detail (i.e. current/next version).
* CU-8694p8y0k: Only run deprecation checking step when merging master into production
* CU-8694u3yd2 cleanup name removal (#450)
* CU-8694u3yd2: Add logged warning for when using full-unlink
* CU-8694u3yd2: Make CDB.remove_names simply expect an iterable of names
* CU-8694u3yd2: Improve CDB.remove_names doc string
* CU-8694u3yd2: Explicitly pass the keys to CDB.remove_names in CAT.unlink_concept_name
* CU-8694u3yd2: Add note regarding state (and order) dependent tests to some CDB maker tests
* CU-8694u3yd2: Rename/make protected CDB.remove_names method
* CU-8694u3yd2: Create deprecated CDB.remove_names method
* CU-8694vte2g 1.12 depr removal (#454)
* CU-8694vte2g: Remove CDB.add_concept method
* CU-8694vte2g: Remove unused import (deprecated decorator)
* CU-8694vte2g: Remove CAT.get_spacy_nlp method
* CU-8694vte2g: Remove CAT.train_supervised method
* CU-8694vte2g: Remove CAT multiprocessing methods
* CU-8694vte2g: Remove MetaCAT.train method
* CU-8694vte2g: Remove medcat.utils.ner.helper.deid_text method
* CU-8694vte2g: Remove use of deprecated method
* CU-8694vte2g: Add back removed deprecation import
---------
Co-authored-by: Shubham Agarwal <66172189+shubham-s-agarwal@users.noreply.github.com>
Co-authored-by: Vlad Dinu <62345326+vladd-bit@users.noreply.github.com>
Co-authored-by: Tom Searle
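For reference, the MetaCAT load-and-train split described above, as a minimal
sketch. The export path, save directory, and the prepared tokenizer are
assumptions for illustration, not values from this PR:

    from medcat.config_meta_cat import ConfigMetaCAT
    from medcat.meta_cat import MetaCAT

    # Inference: load() a previously trained MetaCAT directory; no training
    # step is needed before using it in a pipeline.
    mc = MetaCAT.load(save_dir_path="meta_Status")  # hypothetical model directory

    # Training: init() a fresh instance and call train_from_json(); the model
    # dict is never loaded here, except when config.model.phase_number == 2,
    # in which case model.dat is restored from save_dir_path first.
    config = ConfigMetaCAT()
    tokenizer = None  # placeholder: a prepared TokenizerWrapperBase implementation is assumed
    mc = MetaCAT(tokenizer=tokenizer, embeddings=None, config=config)
    report = mc.train_from_json("MedCAT_Export.json", save_dir_path="meta_Status")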
---
 .github/workflows/codeql.yml | 95 ++++
 .github/workflows/main.yml | 29 +-
 install_requires.txt | 24 +
 medcat/cat.py | 173 ++++---
 medcat/cdb.py | 91 ++--
 medcat/config.py | 57 ++-
 medcat/config_meta_cat.py | 19 +-
 medcat/linking/vector_context_model.py | 4 +-
 medcat/meta_cat.py | 135 ++++--
 medcat/ner/transformers_ner.py | 29 +-
 medcat/stats/kfold.py | 436 ++++++++++++++++++
 medcat/stats/mctexport.py | 66 +++
 medcat/stats/stats.py | 6 +-
 medcat/tokenizers/meta_cat_tokenizers.py | 13 +-
 medcat/utils/cdb_state.py | 179 +++++++
 medcat/utils/cdb_utils.py | 2 +-
 medcat/utils/config_utils.py | 66 ++-
 medcat/utils/decorators.py | 20 +-
 medcat/utils/meta_cat/data_utils.py | 169 +++++--
 medcat/utils/meta_cat/ml_utils.py | 206 +++++++--
 medcat/utils/meta_cat/models.py | 174 ++++---
 medcat/utils/ner/__init__.py | 2 +-
 medcat/utils/ner/deid.py | 12 +-
 medcat/utils/ner/helpers.py | 36 --
 medcat/utils/saving/coding.py | 32 +-
 medcat/utils/saving/envsnapshot.py | 73 +++
 setup.py | 38 +-
 tests/__init__.py | 25 +
 tests/archive_tests/test_cdb_maker_archive.py | 124 -----
 tests/archive_tests/test_ner_archive.py | 139 ------
 tests/check_deprecations.py | 178 +++++++
 tests/medmentions/make_cdb.py | 120 -----
 tests/medmentions/prepare_data.py | 7 -
 tests/resources/jsonpickle_config.json | 274 +++++++++++
 .../resources/jsonpickle_meta_cat_config.json | 89 ++++
 .../resources/jsonpickle_rel_cat_config.json | 91 ++++
 tests/resources/jsonpickle_tner_config.json | 23 +
 .../medcat_trainer_export_FAKE_CONCEPTS.json | 84 ++++
 {webapp/webapp/demo => tests/stats}/__init__.py | 0
 tests/stats/helpers.py | 17 +
 tests/stats/test_kfold.py | 298 ++++++++++++
 tests/stats/test_mctexport.py | 38 ++
 tests/test_cat.py | 28 +-
 tests/test_cdb_maker.py | 20 +-
 tests/test_config.py | 39 ++
 tests/test_meta_cat.py | 50 +-
 tests/utils/saving/test_envsnapshot.py | 105 +++++
 tests/utils/saving/test_serialization.py | 8 +-
 tests/utils/test_cdb_state.py | 113 +++++
 tests/utils/test_config_utils.py | 121 +++++
 webapp/.gitignore | 6 -
 webapp/README.md | 1 -
 webapp/docker-compose.yml | 26 --
 webapp/envs/env_db_backup | 8 -
 webapp/envs/env_medmen | 1 -
 webapp/webapp/.dockerignore | 2 -
 webapp/webapp/Dockerfile | 37 --
 webapp/webapp/data/.keep | 0
 webapp/webapp/db/.keep | 0
 webapp/webapp/demo/admin.py | 16 -
 webapp/webapp/demo/apps.py | 5 -
 webapp/webapp/demo/db_backup.py | 20 -
 webapp/webapp/demo/forms.py | 48 --
 webapp/webapp/demo/migrations/0001_initial.py | 22 -
 .../migrations/0002_downloader_medcatmodel.py | 38 --
 webapp/webapp/demo/migrations/__init__.py | 0
 webapp/webapp/demo/models.py | 31 --
 webapp/webapp/demo/static/css/annotations.css | 110 -----
 webapp/webapp/demo/static/css/base.css | 86 ----
 webapp/webapp/demo/static/css/home.css | 23 -
 webapp/webapp/demo/static/image/favicon.ico | Bin 4641 -> 0 bytes
 webapp/webapp/demo/static/js/.keep | 0
 webapp/webapp/demo/static/js/anns.js | 95 ----
 webapp/webapp/demo/templates/base.html | 33 --
 .../demo/templates/train_annotations.html | 147 ------
 .../demo/templates/umls_user_validation.html | 67 ---
 webapp/webapp/demo/tests.py | 3 -
 webapp/webapp/demo/urls.py | 9 -
 webapp/webapp/demo/views.py | 129 ------
 webapp/webapp/etc/cron.d/db-backup-cron | 1 -
 webapp/webapp/manage.py | 21 -
 webapp/webapp/models/.keep | 0
 webapp/webapp/requirements.txt | 6 -
 webapp/webapp/webapp/__init__.py | 0
 webapp/webapp/webapp/settings.py | 146 ------
 webapp/webapp/webapp/urls.py | 26 --
 webapp/webapp/webapp/wsgi.py | 16 -
 87 files changed, 3316 insertions(+), 2040 deletions(-)
 create mode 100644 .github/workflows/codeql.yml
 create mode 100644 install_requires.txt
 create mode 100644 medcat/stats/kfold.py
 create mode 100644 medcat/stats/mctexport.py
 create mode 100644 medcat/utils/cdb_state.py
 create mode 100644 medcat/utils/saving/envsnapshot.py
 delete mode 100644 tests/archive_tests/test_cdb_maker_archive.py
 delete mode 100644 tests/archive_tests/test_ner_archive.py
 create mode 100644 tests/check_deprecations.py
 delete mode 100644 tests/medmentions/make_cdb.py
 delete mode 100644 tests/medmentions/prepare_data.py
 create mode 100644 tests/resources/jsonpickle_config.json
 create mode 100644 tests/resources/jsonpickle_meta_cat_config.json
 create mode 100644 tests/resources/jsonpickle_rel_cat_config.json
 create mode 100644 tests/resources/jsonpickle_tner_config.json
 create mode 100644 tests/resources/medcat_trainer_export_FAKE_CONCEPTS.json
 rename {webapp/webapp/demo => tests/stats}/__init__.py (100%)
 create mode 100644 tests/stats/helpers.py
 create mode 100644 tests/stats/test_kfold.py
 create mode 100644 tests/stats/test_mctexport.py
 create mode 100644 tests/utils/saving/test_envsnapshot.py
 create mode 100644 tests/utils/test_cdb_state.py
 create mode 100644 tests/utils/test_config_utils.py
 delete mode 100644 webapp/.gitignore
 delete mode 100644 webapp/README.md
 delete mode 100644 webapp/docker-compose.yml
 delete mode 100644 webapp/envs/env_db_backup
 delete mode 100644 webapp/envs/env_medmen
 delete mode 100644 webapp/webapp/.dockerignore
 delete mode 100644 webapp/webapp/Dockerfile
 delete mode 100644 webapp/webapp/data/.keep
 delete mode 100644 webapp/webapp/db/.keep
 delete mode 100644 webapp/webapp/demo/admin.py
 delete mode 100644 webapp/webapp/demo/apps.py
 delete mode 100644 webapp/webapp/demo/db_backup.py
 delete mode 100644 webapp/webapp/demo/forms.py
 delete mode 100644 webapp/webapp/demo/migrations/0001_initial.py
 delete mode 100644 webapp/webapp/demo/migrations/0002_downloader_medcatmodel.py
 delete mode 100644 webapp/webapp/demo/migrations/__init__.py
 delete mode 100644 webapp/webapp/demo/models.py
 delete mode 100644 webapp/webapp/demo/static/css/annotations.css
 delete mode 100644 webapp/webapp/demo/static/css/base.css
 delete mode 100644 webapp/webapp/demo/static/css/home.css
 delete mode 100644 webapp/webapp/demo/static/image/favicon.ico
 delete mode 100644 webapp/webapp/demo/static/js/.keep
 delete mode 100644 webapp/webapp/demo/static/js/anns.js
 delete mode 100644 webapp/webapp/demo/templates/base.html
 delete mode 100644 webapp/webapp/demo/templates/train_annotations.html
 delete mode 100644 webapp/webapp/demo/templates/umls_user_validation.html
 delete mode 100644 webapp/webapp/demo/tests.py
 delete mode 100644 webapp/webapp/demo/urls.py
 delete mode 100644 webapp/webapp/demo/views.py
 delete mode 100644 webapp/webapp/etc/cron.d/db-backup-cron
 delete mode 100755 webapp/webapp/manage.py
 delete mode 100644 webapp/webapp/models/.keep
 delete mode 100644 webapp/webapp/requirements.txt
 delete mode 100644 webapp/webapp/webapp/__init__.py
 delete mode 100644 webapp/webapp/webapp/settings.py
 delete mode 100644 webapp/webapp/webapp/urls.py
 delete mode 100644 webapp/webapp/webapp/wsgi.py

diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml
new file mode 100644
index 000000000..9984edc16
--- /dev/null
+++ b/.github/workflows/codeql.yml
@@ -0,0 +1,95 @@
+# For most projects, this workflow file will not need changing; you simply need
+# to commit it to your repository.
+#
+# You may wish to alter this file to override the set of languages analyzed,
+# or to provide custom queries or build logic.
+#
+# ******** NOTE ********
+# We have attempted to detect the languages in your repository. Please check
+# the `language` matrix defined below to confirm you have the correct set of
+# supported CodeQL languages.
+#
+name: "CodeQL"
+
+on:
+  push:
+    branches: [ "master" ]
+  pull_request:
+    branches: [ "master" ]
+  schedule:
+    - cron: '36 14 * * 0'
+
+jobs:
+  analyze:
+    name: Analyze (${{ matrix.language }})
+    # Runner size impacts CodeQL analysis time. To learn more, please see:
+    #   - https://gh.io/recommended-hardware-resources-for-running-codeql
+    #   - https://gh.io/supported-runners-and-hardware-resources
+    #   - https://gh.io/using-larger-runners (GitHub.com only)
+    # Consider using larger runners or machines with greater resources for possible analysis time improvements.
+    runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }}
+    timeout-minutes: ${{ (matrix.language == 'swift' && 120) || 360 }}
+    permissions:
+      # required for all workflows
+      security-events: write
+
+      # required to fetch internal or private CodeQL packs
+      packages: read
+
+      # only required for workflows in private repositories
+      actions: read
+      contents: read
+
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+        - language: javascript-typescript
+          build-mode: none
+        - language: python
+          build-mode: none
+        # CodeQL supports the following values for 'language': 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift'
+        # Use 'c-cpp' to analyze code written in C, C++ or both
+        # Use 'java-kotlin' to analyze code written in Java, Kotlin or both
+        # Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both
+        # To learn more about changing the languages that are analyzed or customizing the build mode for your analysis,
+        # see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning.
+        # If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how
+        # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
+    steps:
+    - name: Checkout repository
+      uses: actions/checkout@v4
+
+    # Initializes the CodeQL tools for scanning.
+    - name: Initialize CodeQL
+      uses: github/codeql-action/init@v3
+      with:
+        languages: ${{ matrix.language }}
+        build-mode: ${{ matrix.build-mode }}
+        # If you wish to specify custom queries, you can do so here or in a config file.
+        # By default, queries listed here will override any specified in a config file.
+        # Prefix the list here with "+" to use these queries and those in the config file.
+
+        # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
+        # queries: security-extended,security-and-quality
+
+    # If the analyze step fails for one of the languages you are analyzing with
+    # "We were unable to automatically build your code", modify the matrix above
+    # to set the build mode to "manual" for that language. Then modify this step
+    # to build your code.
+    # ℹ️ Command-line programs to run using the OS shell.
+    # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
+    - if: matrix.build-mode == 'manual'
+      shell: bash
+      run: |
+        echo 'If you are using a "manual" build mode for one or more of the' \
+          'languages you are analyzing, replace this with the commands to build' \
+          'your code, for example:'
+        echo '  make bootstrap'
+        echo '  make release'
+        exit 1
+
+    - name: Perform CodeQL Analysis
+      uses: github/codeql-action/analyze@v3
+      with:
+        category: "/language:${{matrix.language}}"
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 7c7a2b742..d446160c9 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -33,7 +33,32 @@ jobs:
         flake8 medcat
     - name: Test
      run: |
-        timeout 17m python -m unittest discover
+        all_files=$(git ls-files | grep '^tests/.*\.py$' | grep -v '/__init__\.py$' | sed 's/\.py$//' | sed 's/\//./g')
+        num_files=$(echo "$all_files" | wc -l)
+        midpoint=$((num_files / 2))
+        first_half_nl=$(echo "$all_files" | head -n $midpoint)
+        second_half_nl=$(echo "$all_files" | tail -n +$(($midpoint + 1)))
+        timeout 25m python -m unittest ${first_half_nl[@]}
+        timeout 25m python -m unittest ${second_half_nl[@]}
+
+    - name: Get the latest release version
+      id: get_latest_release
+      uses: actions/github-script@v6
+      with:
+        script: |
+          const latestRelease = await github.rest.repos.getLatestRelease({
+            owner: context.repo.owner,
+            repo: context.repo.repo
+          });
+          core.setOutput('latest_version', latestRelease.data.tag_name);
+
+    - name: Make sure there's no deprecated methods that should be removed.
+      # only run this for master -> production PR. I.e. just before doing a release.
+      if: github.event.pull_request.base.ref == 'main' && github.event.pull_request.head.ref == 'production'
+      env:
+        VERSION: ${{ steps.get_latest_release.outputs.latest_version }}
+      run: |
+        python tests/check_deprecations.py "$VERSION" --next-version --remove-prefix
 
   publish-to-test-pypi:
 
@@ -43,7 +68,7 @@
       github.event_name == 'push' &&
       startsWith(github.ref, 'refs/tags') != true
     runs-on: ubuntu-20.04
-    timeout-minutes: 20
+    timeout-minutes: 45
     concurrency: publish-to-test-pypi
     needs: [build]
diff --git a/install_requires.txt b/install_requires.txt
new file mode 100644
index 000000000..da26267aa
--- /dev/null
+++ b/install_requires.txt
@@ -0,0 +1,24 @@
+'numpy>=1.22.0,<1.26.0' # 1.22.0 is first to support python 3.11; post 1.26.0 there are issues with scipy
+'pandas>=1.4.2' # first to support 3.11
+'gensim>=4.3.0,<5.0.0' # 4.3.0 is first to support 3.11; avoid major version bump
+'spacy>=3.6.0,<4.0.0' # Some later model packs (e.g. HPO) are made with the 3.6.0 spacy model; avoid major version bump
+'scipy~=1.9.2' # 1.9.2 is first to support 3.11
+'transformers>=4.34.0,<5.0.0' # avoid major version bump
+'accelerate>=0.23.0' # required by Trainer class in de-id
+'torch>=1.13.0,<3.0.0' # 1.13 is first to support 3.11; 2.1.2 has been compatible, but avoid major 3.0.0 for now
+'tqdm>=4.27'
+'scikit-learn>=1.1.3,<2.0.0' # 1.1.3 is first to support 3.11; avoid major version bump
+'dill>=0.3.6,<1.0.0' # stuff saved in 0.3.6/0.3.7 is not always compatible with 0.3.4/0.3.5; avoid major bump
+'datasets>=2.2.2,<3.0.0' # avoid major bump
+'jsonpickle>=2.0.0' # allow later versions, tested with 3.0.0
+'psutil>=5.8.0'
+# 0.70.12 uses older version of dill (i.e. less than 0.3.5) which is required for datasets
+'multiprocess~=0.70.12' # 0.70.14 seemed to work just fine
+'aiofiles>=0.8.0' # allow later versions, tested with 22.1.0
+'ipywidgets>=7.6.5' # allow later versions, tested with 8.0.0
+'xxhash>=3.0.0' # allow later versions, tested with 3.1.0
+'blis>=0.7.5' # allow later versions, tested with 0.7.9
+'click>=8.0.4' # allow later versions, tested with 8.1.3
+'pydantic>=1.10.0,<2.0' # for spacy compatibility; avoid 2.0 due to breaking changes
+"humanfriendly~=10.0" # for human readable file / RAM sizes
+"peft>=0.8.2"
\ No newline at end of file
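The env-snapshot module itself (medcat/utils/saving/envsnapshot.py, new per the
diffstat) is not shown in this excerpt; cat.py below only imports
get_environment_info and ENV_SNAPSHOT_FILE_NAME from it. As a rough sketch of
how a snapshot of direct dependencies can be built on top of
install_requires.txt. The file name constant and all parsing details here are
assumptions, not the shipped implementation:

    import re
    import sys
    import platform
    from importlib.metadata import version, PackageNotFoundError

    ENV_SNAPSHOT_FILE_NAME = 'environment_snapshot.json'  # name assumed, not taken from this diff


    def _direct_dependencies(path: str = 'install_requires.txt') -> list:
        deps = []
        with open(path) as f:
            for line in f:
                line = line.strip().strip("'\"")
                if not line or line.startswith('#'):
                    continue
                # keep just the package name, e.g. 'numpy>=1.22.0,<1.26.0' -> 'numpy'
                deps.append(re.split(r'[<>=~!\[]', line, maxsplit=1)[0])
        return deps


    def get_environment_info() -> dict:
        # record the installed version of each direct dependency
        versions = {}
        for dep in _direct_dependencies():
            try:
                versions[dep] = version(dep)
            except PackageNotFoundError:
                versions[dep] = 'not installed'
        return {
            'python_version': sys.version,
            'os': platform.platform(),
            'dependencies': versions,
        }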
- """Returns the spacy pipeline with MedCAT - - Returns: - Language: The spacy Language being used. - """ - return self.pipe.spacy_nlp - def get_hash(self, force_recalc: bool = False) -> str: """Will not be a deep hash but will try to catch all the changing parts during training. @@ -315,6 +307,12 @@ def create_model_pack(self, save_dir_path: str, model_pack_name: str = DEFAULT_M with open(model_card_path, 'w') as f: json.dump(self.get_model_card(as_dict=True), f, indent=2) + # add a dependency snapshot + env_info = get_environment_info() + env_info_path = os.path.join(save_dir_path, ENV_SNAPSHOT_FILE_NAME) + with open(env_info_path, 'w') as f: + json.dump(env_info, f) + # Zip everything shutil.make_archive(os.path.join(_save_dir_path, model_pack_name), 'zip', root_dir=save_dir_path) @@ -387,12 +385,7 @@ def load_model_pack(cls, model_pack_path = cls.attempt_unpack(zip_path) # Load the CDB - cdb_path = os.path.join(model_pack_path, "cdb.dat") - nr_of_jsons_expected = len(SPECIALITY_NAMES) - len(ONE2MANY) - has_jsons = len(glob.glob(os.path.join(model_pack_path, '*.json'))) >= nr_of_jsons_expected - json_path = model_pack_path if has_jsons else None - logger.info('Loading model pack with %s', 'JSON format' if json_path else 'dill format') - cdb = CDB.load(cdb_path, json_path) + cdb: CDB = cls.load_cdb(model_pack_path) # load config config_path = os.path.join(model_pack_path, "config.json") @@ -419,11 +412,9 @@ def load_model_pack(cls, addl_ner.append(trf) # Find metacat models in the model_pack - meta_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('meta_')] if load_meta_models else [] - meta_cats = [] - for meta_path in meta_paths: - meta_cats.append(MetaCAT.load(save_dir_path=meta_path, - config_dict=meta_cat_config_dict)) + meta_cats: List[MetaCAT] = [] + if load_meta_models: + meta_cats = [mc[1] for mc in cls.load_meta_cats(model_pack_path, meta_cat_config_dict)] # Find Rel models in model_pack rel_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('rel_')] if load_rel_models else [] @@ -436,6 +427,47 @@ def load_model_pack(cls, return cat + @classmethod + def load_cdb(cls, model_pack_path: str) -> CDB: + """ + Loads the concept database from the provided model pack path + + Args: + model_pack_path (str): path to model pack, zip or dir. + + Returns: + CDB: The loaded concept database + """ + cdb_path = os.path.join(model_pack_path, "cdb.dat") + nr_of_jsons_expected = len(SPECIALITY_NAMES) - len(ONE2MANY) + has_jsons = len(glob.glob(os.path.join(model_pack_path, '*.json'))) >= nr_of_jsons_expected + json_path = model_pack_path if has_jsons else None + logger.info('Loading model pack with %s', 'JSON format' if json_path else 'dill format') + cdb = CDB.load(cdb_path, json_path) + return cdb + + @classmethod + def load_meta_cats(cls, model_pack_path: str, meta_cat_config_dict: Optional[Dict] = None) -> List[Tuple[str, MetaCAT]]: + """ + + Args: + model_pack_path (str): path to model pack, zip or dir. + meta_cat_config_dict (Optional[Dict]): + A config dict that will overwrite existing configs in meta_cat. + e.g. meta_cat_config_dict = {'general': {'device': 'cpu'}}. + Defaults to None. + + Returns: + List[Tuple(str, MetaCAT)]: list of pairs of meta cat model names (i.e. the task name) and the MetaCAT models. 
+ """ + meta_paths = [os.path.join(model_pack_path, path) + for path in os.listdir(model_pack_path) if path.startswith('meta_')] + meta_cats = [] + for meta_path in meta_paths: + meta_cats.append(MetaCAT.load(save_dir_path=meta_path, + config_dict=meta_cat_config_dict)) + return list(zip(meta_paths, meta_cats)) + def __call__(self, text: Optional[str], do_train: bool = False) -> Optional[Doc]: """Push the text through the pipeline. @@ -645,13 +677,16 @@ def unlink_concept_name(self, cui: str, name: str, preprocessed_name: bool = Fal names = prepare_name(name, self.pipe.spacy_nlp, {}, self.config) # If full unlink find all CUIs - if self.config.general.get('full_unlink', False): + if self.config.general.full_unlink: + logger.warning("In the config `full_unlink` is set to `True`. " + "Thus removing all CUIs linked to the specified name" + " (%s)", name) for n in names: cuis.extend(self.cdb.name2cuis.get(n, [])) # Remove name from all CUIs for c in cuis: - self.cdb.remove_names(cui=c, names=names) + self.cdb._remove_names(cui=c, names=names.keys()) def add_and_train_concept(self, cui: str, @@ -725,42 +760,6 @@ def add_and_train_concept(self, for _cui in cuis: self.linker.context_model.train(cui=_cui, entity=spacy_entity, doc=spacy_doc, negative=True) # type: ignore - @deprecated(message="Use train_supervised_from_json to train based on data " - "loaded from a json file") - def train_supervised(self, - data_path: str, - reset_cui_count: bool = False, - nepochs: int = 1, - print_stats: int = 0, - use_filters: bool = False, - terminate_last: bool = False, - use_overlaps: bool = False, - use_cui_doc_limit: bool = False, - test_size: int = 0, - devalue_others: bool = False, - use_groups: bool = False, - never_terminate: bool = False, - train_from_false_positives: bool = False, - extra_cui_filter: Optional[Set] = None, - retain_extra_cui_filter: bool = False, - checkpoint: Optional[Checkpoint] = None, - retain_filters: bool = False, - is_resumed: bool = False) -> Tuple: - """Train supervised by reading data from a json file. - - Refer to `train_supervvised_from_json` and/or `train_supervised_raw` - for further details. 
- - # noqa: DAR101 - # noqa: DAR201 - """ - return self.train_supervised_from_json(data_path, reset_cui_count, nepochs, - print_stats, use_filters, terminate_last, - use_overlaps, use_cui_doc_limit, test_size, - devalue_others, use_groups, never_terminate, - train_from_false_positives, extra_cui_filter, - retain_extra_cui_filter, checkpoint, - retain_filters, is_resumed) def train_supervised_from_json(self, data_path: str, @@ -1226,25 +1225,6 @@ def _save_docs_to_file(self, docs: Iterable, annotated_ids: List[str], save_dir_ pickle.dump((annotated_ids, part_counter), open(annotated_ids_path, 'wb')) return part_counter - @deprecated(message="Use `multiprocessing_batch_char_size` instead") - def multiprocessing(self, - data: Union[List[Tuple], Iterable[Tuple]], - nproc: int = 2, - batch_size_chars: int = 5000 * 1000, - only_cui: bool = False, - addl_info: List[str] = ['cui2icd10', 'cui2ontologies', 'cui2snomed'], - separate_nn_components: bool = True, - out_split_size_chars: Optional[int] = None, - save_dir_path: str = os.path.abspath(os.getcwd()), - min_free_memory=0.1) -> Dict: - return self.multiprocessing_batch_char_size(data=data, nproc=nproc, - batch_size_chars=batch_size_chars, - only_cui=only_cui, addl_info=addl_info, - separate_nn_components=separate_nn_components, - out_split_size_chars=out_split_size_chars, - save_dir_path=save_dir_path, - min_free_memory=min_free_memory) - def multiprocessing_batch_char_size(self, data: Union[List[Tuple], Iterable[Tuple]], nproc: int = 2, @@ -1499,21 +1479,6 @@ def _multiprocessing_batch(self, return docs - @deprecated(message="Use `multiprocessing_batch_docs_size` instead") - def multiprocessing_pipe(self, in_data: Union[List[Tuple], Iterable[Tuple]], - nproc: Optional[int] = None, - batch_size: Optional[int] = None, - only_cui: bool = False, - addl_info: List[str] = [], - return_dict: bool = True, - batch_factor: int = 2) -> Union[List[Tuple], Dict]: - return self.multiprocessing_batch_docs_size(in_data=in_data, nproc=nproc, - batch_size=batch_size, - only_cui=only_cui, - addl_info=addl_info, - return_dict=return_dict, - batch_factor=batch_factor) - def multiprocessing_batch_docs_size(self, in_data: Union[List[Tuple], Iterable[Tuple]], nproc: Optional[int] = None, @@ -1526,6 +1491,11 @@ def multiprocessing_batch_docs_size(self, This method batches the data based on the number of documents as specified by the user. + NOTE: When providing a generator for `data`, the generator is evaluated (`list(in_data)`) + and thus all the data is kept in memory and (potentially) duplicated for use in + multiple threads. So if you're using a lot of data, it may be better to use + `CAT.multiprocessing_batch_char_size` instead. + PS: This method supports Windows. @@ -1550,7 +1520,20 @@ def multiprocessing_batch_docs_size(self, if nproc == 0: raise ValueError("nproc cannot be set to zero") - in_data = list(in_data) if isinstance(in_data, Iterable) else in_data + # TODO: Surely there's a way to not materialise all of the incoming data in memory? + # This is counter productive for allowing the passing of generators. + if isinstance(in_data, Iterable): + in_data = list(in_data) + in_data_len = len(in_data) + if in_data_len > MIN_GEN_LEN_FOR_WARN: + # only point this out when it's relevant, i.e over 10k items + logger.warning("The `CAT.multiprocessing_batch_docs_size` method just " + f"materialised {in_data_len} items from the generator it " + "was provided. 
This may use up a considerable amount of " + "RAM, especially since the data may be duplicated across " + "multiple threads when multiprocessing is used. If the " + "process is kiled after this warning, please use the " + "alternative method `multiprocessing_batch_char_size` instead") n_process = nproc if nproc is not None else min(max(cpu_count() - 1, 1), math.ceil(len(in_data) / batch_factor)) batch_size = batch_size if batch_size is not None else math.ceil(len(in_data) / (batch_factor * abs(n_process))) diff --git a/medcat/cdb.py b/medcat/cdb.py index 6ae15d3f5..e63843364 100644 --- a/medcat/cdb.py +++ b/medcat/cdb.py @@ -5,17 +5,20 @@ import logging import aiofiles import numpy as np -from typing import Dict, Set, Optional, List, Union, cast -from functools import partial +from typing import Dict, Set, Optional, List, Union, cast, Iterable import os from medcat import __version__ from medcat.utils.hasher import Hasher from medcat.utils.matutils import unitvec from medcat.utils.ml_utils import get_lr_linking +from medcat.config import Config, workers from medcat.utils.decorators import deprecated -from medcat.config import Config, weighted_average, workers from medcat.utils.saving.serializer import CDBSerializer +from medcat.utils.config_utils import get_and_del_weighted_average_from_config +from medcat.utils.config_utils import default_weighted_average +from medcat.utils.config_utils import ensure_backward_compatibility +from medcat.utils.config_utils import fix_waf_lambda, attempt_fix_weighted_average_function logger = logging.getLogger(__name__) @@ -98,6 +101,7 @@ def __init__(self, config: Union[Config, None] = None) -> None: self.vocab: Dict = {} # Vocabulary of all words ever in our cdb self._optim_params = None self.is_dirty = False + self._init_waf_from_config() self._hash: Optional[str] = None # the config hash is kept track of here so that # the CDB hash can be re-calculated when the config changes @@ -107,6 +111,18 @@ def __init__(self, config: Union[Config, None] = None) -> None: self._config_hash: Optional[str] = None self._memory_optimised_parts: Set[str] = set() + def _init_waf_from_config(self): + waf = get_and_del_weighted_average_from_config(self.config) + if waf is not None: + logger.info("Using (potentially) custom value of weighed " + "average function") + self.weighted_average_function = attempt_fix_weighted_average_function(waf) + elif hasattr(self, 'weighted_average_function'): + # keep existing + pass + else: + self.weighted_average_function = default_weighted_average + def get_name(self, cui: str) -> str: """Returns preferred name if it exists, otherwise it will return the longest name assigned to the concept. @@ -132,7 +148,12 @@ def update_cui2average_confidence(self, cui: str, new_sim: float) -> None: (self.cui2count_train.get(cui, 0) + 1) self.is_dirty = True - def remove_names(self, cui: str, names: Dict[str, Dict]) -> None: + @deprecated("Deprecated. For internal use only. Use CAT.unlink_concept_name instead", + depr_version=(1, 12, 0), removal_version=(1, 13, 0)) + def remove_names(self, cui: str, names: Iterable[str]) -> None: + self._remove_names(cui, names) + + def _remove_names(self, cui: str, names: Iterable[str]) -> None: """Remove names from an existing concept - effect is this name will never again be used to link to this concept. This will only remove the name from the linker (namely name2cuis and name2cuis2status), the name will still be present everywhere else. Why? 
diff --git a/medcat/cdb.py b/medcat/cdb.py
index 6ae15d3f5..e63843364 100644
--- a/medcat/cdb.py
+++ b/medcat/cdb.py
@@ -5,17 +5,20 @@
 import logging
 import aiofiles
 import numpy as np
-from typing import Dict, Set, Optional, List, Union, cast
-from functools import partial
+from typing import Dict, Set, Optional, List, Union, cast, Iterable
 import os
 
 from medcat import __version__
 from medcat.utils.hasher import Hasher
 from medcat.utils.matutils import unitvec
 from medcat.utils.ml_utils import get_lr_linking
+from medcat.config import Config, workers
 from medcat.utils.decorators import deprecated
-from medcat.config import Config, weighted_average, workers
 from medcat.utils.saving.serializer import CDBSerializer
+from medcat.utils.config_utils import get_and_del_weighted_average_from_config
+from medcat.utils.config_utils import default_weighted_average
+from medcat.utils.config_utils import ensure_backward_compatibility
+from medcat.utils.config_utils import fix_waf_lambda, attempt_fix_weighted_average_function
 
 logger = logging.getLogger(__name__)
@@ -98,6 +101,7 @@ def __init__(self, config: Union[Config, None] = None) -> None:
         self.vocab: Dict = {}  # Vocabulary of all words ever in our cdb
         self._optim_params = None
         self.is_dirty = False
+        self._init_waf_from_config()
         self._hash: Optional[str] = None
         # the config hash is kept track of here so that
         # the CDB hash can be re-calculated when the config changes
@@ -107,6 +111,18 @@ def __init__(self, config: Union[Config, None] = None) -> None:
         self._config_hash: Optional[str] = None
         self._memory_optimised_parts: Set[str] = set()
 
+    def _init_waf_from_config(self):
+        waf = get_and_del_weighted_average_from_config(self.config)
+        if waf is not None:
+            logger.info("Using (potentially) custom value of weighted "
+                        "average function")
+            self.weighted_average_function = attempt_fix_weighted_average_function(waf)
+        elif hasattr(self, 'weighted_average_function'):
+            # keep existing
+            pass
+        else:
+            self.weighted_average_function = default_weighted_average
+
     def get_name(self, cui: str) -> str:
         """Returns preferred name if it exists, otherwise it will return
         the longest name assigned to the concept.
@@ -132,7 +148,12 @@ def update_cui2average_confidence(self, cui: str, new_sim: float) -> None:
                                            (self.cui2count_train.get(cui, 0) + 1)
         self.is_dirty = True
 
-    def remove_names(self, cui: str, names: Dict[str, Dict]) -> None:
+    @deprecated("Deprecated. For internal use only. Use CAT.unlink_concept_name instead",
+                depr_version=(1, 12, 0), removal_version=(1, 13, 0))
+    def remove_names(self, cui: str, names: Iterable[str]) -> None:
+        self._remove_names(cui, names)
+
+    def _remove_names(self, cui: str, names: Iterable[str]) -> None:
         """Remove names from an existing concept - effect is this name will never again be used to link to this concept.
         This will only remove the name from the linker (namely name2cuis and name2cuis2status), the name will still
        be present everywhere else. Why? Because it is bothersome to remove it from everywhere, but
@@ -141,10 +162,10 @@ def remove_names(self, cui: str, names: Dict[str, Dict]) -> None:
         Args:
             cui (str):
                 Concept ID or unique identifier in this database.
-            names (Dict[str, Dict]):
-                Names to be removed, should look like: `{'name': {'tokens': tokens, 'snames': snames, 'raw_name': raw_name}, ...}`
+            names (Iterable[str]):
+                Names to be removed (e.g. list, set, or even a dict (in which case keys will be used)).
         """
-        for name in names.keys():
+        for name in names:
             if name in self.name2cuis:
                 if cui in self.name2cuis[name]:
                     self.name2cuis[name].remove(cui)
@@ -231,44 +252,6 @@ def add_names(self, cui: str, names: Dict[str, Dict], name_status: str = 'A', fu
         self._add_concept(cui=cui, names=names, ontologies=set(), name_status=name_status,
                           type_ids=set(), description='', full_build=full_build)
 
-    @deprecated("Use `cdb._add_concept` as this will be removed in a future release.")
-    def add_concept(self,
-                    cui: str,
-                    names: Dict[str, Dict],
-                    ontologies: Set[str],
-                    name_status: str,
-                    type_ids: Set[str],
-                    description: str,
-                    full_build: bool = False) -> None:
-        """
-        Deprecated: Use `cdb._add_concept` as this will be removed in a future release.
-
-        Add a concept to internal Concept Database (CDB). Depending on what you are providing
-        this will add a large number of properties for each concept.
-
-        Args:
-            cui (str):
-                Concept ID or unique identifier in this database, all concepts that have
-                the same CUI will be merged internally.
-            names (Dict[str, Dict]):
-                Names for this concept, or the value that if found in free text can be linked to this concept.
-                Names is a dict like: `{name: {'tokens': tokens, 'snames': snames, 'raw_name': raw_name}, ...}`
-                Names should be generated by helper function 'medcat.preprocessing.cleaners.prepare_name'
-            ontologies (Set[str]):
-                ontologies in which the concept exists (e.g. SNOMEDCT, HPO)
-            name_status (str):
-                One of `P`, `N`, `A`
-            type_ids (Set[str]):
-                Semantic type identifier (have a look at TUIs in UMLS or SNOMED-CT)
-            description (str):
-                Description of this concept.
-            full_build (bool):
-                If True the dictionary self.addl_info will also be populated, contains a lot of extra information
-                about concepts, but can be very memory consuming. This is not necessary
-                for normal functioning of MedCAT (Default Value `False`).
-        """
-        self._add_concept(cui, names, ontologies, name_status, type_ids, description, full_build)
-
     def _add_concept(self,
                      cui: str,
                      names: Dict[str, Dict],
@@ -558,6 +541,8 @@ def load_config(self, config_path: str) -> None:
             # this should be the behaviour for all newer models
             self.config = cast(Config, Config.load(config_path))
             logger.debug("Loaded config from CDB from %s", config_path)
+            # new config, potentially new weighted_average_function to read
+            self._init_waf_from_config()
 
         # mark config read from file
         self._config_from_file = True
@@ -582,7 +567,8 @@ def load(cls, path: str, json_path: Optional[str] = None, config_dict: Optional[
         ser = CDBSerializer(path, json_path)
         cdb = ser.deserialize(CDB)
         cls._check_medcat_version(cdb.config.asdict())
-        cls._ensure_backward_compatibility(cdb.config)
+        fix_waf_lambda(cdb)
+        ensure_backward_compatibility(cdb.config, workers)
 
         # Overwrite the config with new data
         if config_dict is not None:
@@ -855,19 +841,6 @@ def most_similar(self,
         return res
 
-    @staticmethod
-    def _ensure_backward_compatibility(config: Config) -> None:
-        # Hacky way of supporting old CDBs
-        weighted_average_function = config.linking.weighted_average_function
-        if callable(weighted_average_function) and getattr(weighted_average_function, "__name__", None) == "<lambda>":
-            # the following type ignoring is for mypy because it is unable to detect the signature
-            config.linking.weighted_average_function = partial(weighted_average, factor=0.0004)  # type: ignore
-        if config.general.workers is None:
-            config.general.workers = workers()
-        disabled_comps = config.general.spacy_disabled_components
-        if 'tagger' in disabled_comps and 'lemmatizer' not in disabled_comps:
-            config.general.spacy_disabled_components.append('lemmatizer')
-
     @classmethod
     def _check_medcat_version(cls, config_data: Dict) -> None:
         cdb_medcat_version = config_data.get('version', {}).get('medcat_version', None)
diff --git a/medcat/config.py b/medcat/config.py
index 88c4aad14..cdb30d0fe 100644
--- a/medcat/config.py
+++ b/medcat/config.py
@@ -1,17 +1,19 @@
 from datetime import datetime
 from pydantic import BaseModel, Extra, ValidationError
 from pydantic.fields import ModelField
-from typing import List, Set, Tuple, cast, Any, Callable, Dict, Optional, Union
+from typing import List, Set, Tuple, cast, Any, Callable, Dict, Optional, Union, Type
 from multiprocessing import cpu_count
 import logging
 import jsonpickle
+import json
 from functools import partial
 import re
 
 from medcat.utils.hasher import Hasher
 from medcat.utils.matutils import intersect_nonempty_set
 from medcat.utils.config_utils import attempt_fix_weighted_average_function
-from medcat.utils.config_utils import weighted_average
+from medcat.utils.config_utils import weighted_average, is_old_type_config_dict
+from medcat.utils.saving.coding import CustomDelegatingEncoder, default_hook
 
 logger = logging.getLogger(__name__)
@@ -31,6 +33,7 @@ def __getitem__(self, arg: str) -> Any:
             raise KeyError from e
 
     def __setattr__(self, arg: str, val) -> None:
+        # TODO: remove this in the future when we stop supporting this in config
         if isinstance(self, Linking) and arg == "weighted_average_function":
             val = attempt_fix_weighted_average_function(val)
         super().__setattr__(arg, val)
@@ -103,8 +106,8 @@ def save(self, save_path: str) -> None:
             save_path(str): Where to save the created json file
         """
         # We want to save the dict here, not the whole class
-        json_string = jsonpickle.encode(
-            {field: getattr(self, field) for field in self.fields()})
+        json_string = json.dumps(self.asdict(), cls=cast(Type[json.JSONEncoder],
+                                                         CustomDelegatingEncoder.def_inst))
 
         with open(save_path, 'w') as f:
             f.write(json_string)
@@ -204,7 +207,12 @@ def load(cls, save_path: str) -> "MixingConfig":
 
         # Read the jsonpickle string
         with open(save_path) as f:
-            config_dict = jsonpickle.decode(f.read())
+            config_dict = json.load(f, object_hook=default_hook)
+        if is_old_type_config_dict(config_dict):
+            logger.warning("Loading an old type of config (jsonpickle) from '%s'",
+                           save_path)
+            with open(save_path) as f:
+                config_dict = jsonpickle.decode(f.read())
 
         config.merge_config(config_dict)
@@ -511,9 +519,6 @@ class Linking(MixingConfig, BaseModel):
         similarity calculation and will have a similarity of -1."""
     always_calculate_similarity: bool = False
     """Do we want to calculate context similarity even for concepts that are not ambiguous."""
-    weighted_average_function: Callable[..., Any] = _DEFAULT_PARTIAL
-    """Weights for a weighted average
-    'weighted_average_function': partial(weighted_average, factor=0.02),"""
     calculate_dynamic_threshold: bool = False
     """Concepts below this similarity will be ignored. Type can be static/dynamic - if dynamic each CUI has a different TH
     and it is calculated as the average confidence for that CUI * similarity_threshold. Take care that dynamic works only
@@ -597,3 +602,39 @@ def get_hash(self):
                 hasher.update(v2, length=True)
         self.hash = hasher.hexdigest()
         return self.hash
+
+
+class UseOfOldConfigOptionException(AttributeError):
+
+    def __init__(self, conf_type: Type[FakeDict], arg_name: str, advice: str) -> None:
+        super().__init__(f"Tried to use {conf_type.__name__}.{arg_name}. "
+                         f"Advice: {advice}")
+        self.conf_type = conf_type
+        self.arg_name = arg_name
+        self.advice = advice
+
+
+# NOTE: The following is for backwards compatibility and should be removed
+#       at some point in the future
+
+# wrapper for functions for a better error in case of weighted_average_function
+# access
+def _wrapper(func, check_type: Type[FakeDict], advice: str, exp_type: Type[Exception]):
+    def wrapper(*args, **kwargs):
+        try:
+            res = func(*args, **kwargs)
+        except exp_type as ex:
+            if ((len(args) == 2 and len(kwargs) == 0) and
+                    (isinstance(args[0], check_type) and
+                     args[1] == "weighted_average_function")):
+                raise UseOfOldConfigOptionException(Linking, args[1], advice) from ex
+            raise ex
+        return res
+    return wrapper
+
+
+# wrap Linking.__getattribute__ so that when getting weighted_average_function
+# we get a nicer exception
+_waf_advice = "You can use `cat.cdb.weighted_average_function` to access it directly"
+Linking.__getattribute__ = _wrapper(Linking.__getattribute__, Linking, _waf_advice, AttributeError)  # type: ignore
+Linking.__getitem__ = _wrapper(Linking.__getitem__, Linking, _waf_advice, KeyError)  # type: ignore
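What the wrapping above buys the user: old code that still reads the removed
option gets a pointed error instead of a bare AttributeError or KeyError. A
short usage sketch; the model pack path is hypothetical:

    from medcat.cat import CAT
    from medcat.config import UseOfOldConfigOptionException

    cat = CAT.load_model_pack("/path/to/model_pack.zip")  # hypothetical path
    try:
        waf = cat.config.linking.weighted_average_function  # pre-1.12 location
    except UseOfOldConfigOptionException as ex:
        print(ex.advice)  # points to cat.cdb.weighted_average_function
        waf = cat.cdb.weighted_average_function  # new home of the function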
diff --git a/medcat/config_meta_cat.py b/medcat/config_meta_cat.py
index 6ddd71d56..686029052 100644
--- a/medcat/config_meta_cat.py
+++ b/medcat/config_meta_cat.py
@@ -1,5 +1,4 @@
 from typing import Dict, Any
-
 from medcat.config import MixingConfig, BaseModel, Optional, Extra
@@ -49,10 +48,20 @@ class Config:
 
 class Model(MixingConfig, BaseModel):
     """The model part of the metaCAT config"""
     model_name: str = 'lstm'
+    """NOTE: When changing the model, make sure to change the tokenizer as well"""
+    model_variant: str = 'bert-base-uncased'
+    model_freeze_layers: bool = True
     num_layers: int = 2
     input_size: int = 300
     hidden_size: int = 300
     dropout: float = 0.5
+    phase_number: int = 0
+    """Indicates whether or not two-phase learning is being performed.
+    1: Phase 1 - Train the model on undersampled data
+    2: Phase 2 - Continue training on the full data
+    0: None - two-phase learning is not performed"""
+    category_undersample: str = ''
+    model_architecture_config: Dict = {'fc2': True, 'fc3': False, 'lr_scheduler': True}
     num_directions: int = 2
     """2 - bidirectional model, 1 - unidirectional"""
     nclasses: int = 2
@@ -61,7 +70,7 @@ class Model(MixingConfig, BaseModel):
     emb_grad: bool = True
     """If True the embeddings will also be trained"""
     ignore_cpos: bool = False
-    """If set to True center positions will be ignored when calculating represenation"""
+    """If set to True center positions will be ignored when calculating representation"""
 
     class Config:
         extra = Extra.allow
@@ -77,6 +86,8 @@ class Train(MixingConfig, BaseModel):
     shuffle_data: bool = True
     """Used only during training, if set the dataset will be shuffled before train/test split"""
     class_weights: Optional[Any] = None
+    compute_class_weights: bool = False
+    """If true and if class weights are not provided, the class weights will be calculated based on the data"""
     score_average: str = 'weighted'
     """What to use for averaging F1/P/R across labels"""
     prerequisites: dict = {}
@@ -88,6 +99,10 @@ class Train(MixingConfig, BaseModel):
     """When was the last training run"""
     metric: Dict[str, str] = {'base': 'weighted avg', 'score': 'f1-score'}
     """What metric should be used for choosing the best model"""
+    loss_funct: str = 'cross_entropy'
+    """Loss function for the model"""
+    gamma: int = 2
+    """Focal Loss - how much the loss focuses on hard-to-classify examples."""
 
     class Config:
         extra = Extra.allow
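Putting the new Model fields together, two-phase learning as wired up in
meta_cat.train_raw (further down in this diff) runs roughly as follows. The
undersample target, the paths, and the tokenizer setup are assumptions for
illustration:

    from medcat.config_meta_cat import ConfigMetaCAT

    config = ConfigMetaCAT()
    config.model.model_name = 'bert'
    config.model.model_variant = 'bert-base-uncased'
    config.model.category_undersample = 'Other'  # hypothetical majority class

    # Phase 1: train on the undersampled split; train_raw force-enables
    # auto_save_model so model.dat is written to save_dir_path.
    config.model.phase_number = 1
    # mc = MetaCAT(tokenizer, None, config); mc.train_from_json(export, save_dir_path="meta_model")

    # Phase 2: a second run with the same save_dir_path reloads model.dat and
    # continues on the full data (raises FileNotFoundError if phase 1 was skipped).
    config.model.phase_number = 2
    # mc.train_from_json(export, save_dir_path="meta_model")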
medcat.utils.meta_cat.data_utils import Doc as FakeDoc -from medcat.utils.decorators import deprecated +from peft import get_peft_model, LoraConfig, TaskType # It should be safe to do this always, as all other multiprocessing # will be finished before data comes to meta_cat os.environ["TOKENIZERS_PARALLELISM"] = "true" - -logger = logging.getLogger(__name__) # separate logger from the package-level one +logger = logging.getLogger(__name__) # separate logger from the package-level one class MetaCAT(PipeRunner): @@ -77,7 +76,7 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module: The embedding densor Raises: - ValueError: If the meta model is not LSTM + ValueError: If the meta model is not LSTM or BERT Returns: nn.Module: @@ -86,7 +85,22 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module: config = self.config if config.model['model_name'] == 'lstm': from medcat.utils.meta_cat.models import LSTM - model = LSTM(embeddings, config) + model: nn.Module = LSTM(embeddings, config) + logger.info("LSTM model used for classification") + + elif config.model['model_name'] == 'bert': + from medcat.utils.meta_cat.models import BertForMetaAnnotation + model = BertForMetaAnnotation(config) + + if not config.model.model_freeze_layers: + peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=16, + target_modules=["query", "value"], lora_dropout=0.2) + + model = get_peft_model(model, peft_config) + # model.print_trainable_parameters() + + logger.info("BERT model used for classification") + else: raise ValueError("Unknown model name %s" % config.model['model_name']) @@ -106,24 +120,8 @@ def get_hash(self) -> str: hasher.update(self.config.get_hash()) return hasher.hexdigest() - @deprecated(message="Use `train_from_json` or `train_raw` instead") - def train(self, json_path: Union[str, list], save_dir_path: Optional[str] = None) -> Dict: - """Train or continue training a model give a json_path containing a MedCATtrainer export. It will - continue training if an existing model is loaded or start new training if the model is blank/new. - - Args: - json_path (Union[str, list]): - Path/Paths to a MedCATtrainer export containing the meta_annotations we want to train for. - save_dir_path (Optional[str]): - In case we have aut_save_model (meaning during the training the best model will be saved) - we need to set a save path. Defaults to `None`. - - Returns: - Dict: The resulting report. - """ - return self.train_from_json(json_path, save_dir_path) - - def train_from_json(self, json_path: Union[str, list], save_dir_path: Optional[str] = None) -> Dict: + def train_from_json(self, json_path: Union[str, list], save_dir_path: Optional[str] = None, + data_oversampled: Optional[list] = None) -> Dict: """Train or continue training a model give a json_path containing a MedCATtrainer export. It will continue training if an existing model is loaded or start new training if the model is blank/new. @@ -133,6 +131,8 @@ def train_from_json(self, json_path: Union[str, list], save_dir_path: Optional[s save_dir_path (Optional[str]): In case we have aut_save_model (meaning during the training the best model will be saved) we need to set a save path. Defaults to `None`. + data_oversampled (Optional[list]): + In case of oversampling being performed, the data will be passed in the parameter Returns: Dict: The resulting report. 
@@ -157,9 +157,9 @@ def merge_data_loaded(base, other): for path in json_path: with open(path, 'r') as f: data_loaded = merge_data_loaded(data_loaded, json.load(f)) - return self.train_raw(data_loaded, save_dir_path) + return self.train_raw(data_loaded, save_dir_path, data_oversampled=data_oversampled) - def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> Dict: + def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data_oversampled: Optional[list] = None) -> Dict: """Train or continue training a model given raw data. It will continue training if an existing model is loaded or start new training if the model is blank/new. @@ -187,6 +187,10 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> D save_dir_path (Optional[str]): In case we have auto_save_model (meaning during the training the best model will be saved) we need to set a save path. Defaults to `None`. + data_oversampled (Optional[list]): + If oversampling is performed, the oversampled data is passed via this parameter. + The expected format is: [[['text','of','the','document'], [index of medical entity], "label"], + [['text','of','the','document'], [index of medical entity], "label"]] Returns: Dict: The resulting report. @@ -194,6 +198,8 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> D Raises: Exception: If no save path is specified, or category name not in data. AssertionError: If no tokeniser is set + FileNotFoundError: If phase_number is set to 2 and model.dat file is not found + KeyError: If phase_number is set to 2 and model.dat file contains mismatched architecture """ g_config = self.config.general t_config = self.config.train @@ -212,7 +218,7 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> D replace_center=g_config['replace_center'], prerequisites=t_config['prerequisites'], lowercase=g_config['lowercase']) - # Check is the name there + # Check if the name is present category_name = g_config['category_name'] if category_name not in data: raise Exception( @@ -220,16 +226,22 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> D category_name, " | ".join(list(data.keys())))) data = data[category_name] + if data_oversampled: + data_sampled = prepare_for_oversampled_data(data_oversampled, self.tokenizer) + data = data + data_sampled category_value2id = g_config['category_value2id'] if not category_value2id: # Encode the category values - data, category_value2id = encode_category_values(data) + data_undersampled, full_data, category_value2id = encode_category_values(data, + category_undersample=self.config.model.category_undersample) g_config['category_value2id'] = category_value2id else: # We already have everything, just get the data - data, _ = encode_category_values(data, existing_category_value2id=category_value2id) - + data_undersampled, full_data, category_value2id = encode_category_values(data, + existing_category_value2id=category_value2id, + category_undersample=self.config.model.category_undersample) + g_config['category_value2id'] = category_value2id # Make sure the config number of classes is the same as the one found in the data if len(category_value2id) != self.config.model['nclasses']: logger.warning( @@ -237,7 +249,29 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None) -> D self.config.model['nclasses'], len(category_value2id))) logger.warning("Auto-setting the nclasses value in config and rebuilding the
model.") self.config.model['nclasses'] = len(category_value2id) - self.model = self.get_model(embeddings=self.embeddings) + + if self.config.model.phase_number == 2 and save_dir_path is not None: + model_save_path = os.path.join(save_dir_path, 'model.dat') + device = torch.device(g_config['device']) + try: + self.model.load_state_dict(torch.load(model_save_path, map_location=device)) + logger.info("Model state loaded from dict for 2 phase learning") + + except FileNotFoundError: + raise FileNotFoundError(f"\nError: Model file not found at path: {model_save_path}\nPlease run phase 1 training and then run phase 2.") + + except KeyError: + raise KeyError("\nError: Missing key in loaded state dictionary. \nThis might be due to a mismatch between the model architecture and the saved state.") + + except Exception as e: + raise Exception(f"\nError: Model state cannot be loaded from dict. {e}") + + data = full_data + if self.config.model.phase_number == 1: + data = data_undersampled + if not t_config['auto_save_model']: + logger.info("For phase 1, model state has to be saved. Saving model...") + t_config['auto_save_model'] = True report = train_model(self.model, data=data, config=self.config, save_dir_path=save_dir_path) @@ -293,7 +327,7 @@ def eval(self, json_path: str) -> Dict: # We already have everything, just get the data category_value2id = g_config['category_value2id'] - data, _ = encode_category_values(data, existing_category_value2id=category_value2id) + data, _, _ = encode_category_values(data, existing_category_value2id=category_value2id) # Run evaluation assert self.tokenizer is not None @@ -317,8 +351,8 @@ def save(self, save_dir_path: str) -> None: # Save tokenizer assert self.tokenizer is not None self.tokenizer.save(save_dir_path) - # Save config + self.config.save(os.path.join(save_dir_path, 'config.json')) # Save the model @@ -347,7 +381,7 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA # Load config config = cast(ConfigMetaCAT, ConfigMetaCAT.load(os.path.join(save_dir_path, 'config.json'))) - # Overwrite loaded paramters with something new + # Overwrite loaded parameters with something new if config_dict is not None: config.merge_config(config_dict) @@ -358,7 +392,7 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA tokenizer = TokenizerWrapperBPE.load(save_dir_path) elif config.general['tokenizer_name'] == 'bert-tokenizer': from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT - tokenizer = TokenizerWrapperBERT.load(save_dir_path) + tokenizer = TokenizerWrapperBERT.load(save_dir_path, config.model['model_variant']) # Create meta_cat meta_cat = cls(tokenizer=tokenizer, embeddings=None, config=config) @@ -380,7 +414,8 @@ def get_ents(self, doc: Doc) -> Iterable[Span]: try: return doc.spans[spangroup_name] except KeyError: - raise Exception(f"Configuration error MetaCAT was configured to set meta_anns on {spangroup_name} but this spangroup was not set on the doc.") + raise Exception( + f"Configuration error MetaCAT was configured to set meta_anns on {spangroup_name} but this spangroup was not set on the doc.") # Should we annotate overlapping entities if self.config.general['annotate_overlapping']: @@ -421,18 +456,26 @@ def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowe start = ent.start_char end = ent.end_char - ind = 0 - # Start where the last ent was found, cannot be before it as we've sorted + # Updated implementation to extract all the tokens for the medical 
entity (rather than just one) + ctoken_idx = [] for ind, pair in enumerate(offset_mapping[last_ind:]): - if start >= pair[0] and start < pair[1]: - break - ind = last_ind + ind # If we did not start from 0 in the for loop - last_ind = ind + # Checking if we've reached the start of the entity + if start <= pair[0] or start <= pair[1]: + if end <= pair[1]: + ctoken_idx.append(ind) # End reached + break + else: + ctoken_idx.append(ind) # Keep going + + # Start where the last ent was found, cannot be before it as we've sorted + last_ind += ind # If we did not start from 0 in the for loop + + _start = max(0, ctoken_idx[0] - cntx_left) + _end = min(len(input_ids), ctoken_idx[-1] + 1 + cntx_right) - _start = max(0, ind - cntx_left) - _end = min(len(input_ids), ind + 1 + cntx_right) tkns = input_ids[_start:_end] cpos = cntx_left + min(0, ind - cntx_left) + cpos_new = [x - _start for x in ctoken_idx] if replace_center is not None: if lowercase: @@ -447,8 +490,7 @@ def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowe ln = e_ind - s_ind # Length of the concept in tokens assert self.tokenizer is not None tkns = tkns[:cpos] + self.tokenizer(replace_center)['input_ids'] + tkns[cpos + ln + 1:] - - samples.append([tkns, cpos]) + samples.append([tkns, cpos_new]) ent_id2ind[ent._.id] = len(samples) - 1 return ent_id2ind, samples @@ -544,7 +586,6 @@ def _set_meta_anns(self, for i, doc in enumerate(docs): data.extend(doc._.share_tokens[0]) doc_ind2positions[i] = doc._.share_tokens[1] - all_predictions, all_confidences = predict(self.model, data, config) for i, doc in enumerate(docs): start_ind, end_ind, ent_id2ind = doc_ind2positions[i] diff --git a/medcat/ner/transformers_ner.py b/medcat/ner/transformers_ner.py index 7aabceda2..9d4700df9 100644 --- a/medcat/ner/transformers_ner.py +++ b/medcat/ner/transformers_ner.py @@ -4,8 +4,10 @@ import datasets from spacy.tokens import Doc from datetime import datetime -from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Tuple +from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Tuple, Callable from spacy.tokens import Span +import inspect +from functools import partial from medcat.cdb import CDB from medcat.utils.meta_cat.ml_utils import set_all_seeds @@ -178,10 +180,21 @@ def train(self, json_path = self._prepare_dataset(json_path, ignore_extra_labels=ignore_extra_labels, meta_requirements=meta_requirements, file_name='data_eval.json') # Load dataset - dataset = datasets.load_dataset(os.path.abspath(transformers_ner.__file__), - data_files={'train': json_path}, # type: ignore - split='train', - cache_dir='/tmp/') + + # NOTE: The following is for backwards compatibility + # in datasets==2.20.0 `trust_remote_code=True` must be explicitly + # specified, otherwise an error is raised. + # On the other hand, the keyword argument was added in datasets==2.16.0 + # yet we support datasets>=2.2.0. + # So we need to use the kwarg if applicable and omit its use otherwise.
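In isolation, the guard described in the note above reduces to the following small sketch (`load` is a hypothetical stand-in, not the real datasets API):

    import inspect
    from functools import partial

    def load(path, split='train'):  # stand-in signature without trust_remote_code
        ...

    # Pass trust_remote_code=True only when the installed signature accepts it
    if 'trust_remote_code' in inspect.signature(load).parameters:
        load = partial(load, trust_remote_code=True)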
+ if func_has_kwarg(datasets.load_dataset, 'trust_remote_code'): + ds_load_dataset = partial(datasets.load_dataset, trust_remote_code=True) + else: + ds_load_dataset = datasets.load_dataset + dataset = ds_load_dataset(os.path.abspath(transformers_ner.__file__), + data_files={'train': json_path}, # type: ignore + split='train', + cache_dir='/tmp/') # We split before encoding so the split is document level, as encoding #does the document spliting into max_seq_len dataset = dataset.train_test_split(test_size=self.config.general['test_size']) # type: ignore @@ -422,3 +435,9 @@ def __call__(self, doc: Doc) -> Doc: doc = next(self.pipe(iter([doc]))) return doc + + +# NOTE: Only needed for datasets backwards compatibility +def func_has_kwarg(func: Callable, keyword: str): + sig = inspect.signature(func) + return keyword in sig.parameters diff --git a/medcat/stats/kfold.py b/medcat/stats/kfold.py new file mode 100644 index 000000000..491173c23 --- /dev/null +++ b/medcat/stats/kfold.py @@ -0,0 +1,436 @@ +from typing import Protocol, Tuple, List, Dict, Optional, Set, Iterable, Callable, cast, Any + +from abc import ABC, abstractmethod +from enum import Enum, auto +from copy import deepcopy + +import numpy as np + +from medcat.utils.checkpoint import Checkpoint +from medcat.utils.cdb_state import captured_state_cdb + +from medcat.stats.stats import get_stats +from medcat.stats.mctexport import MedCATTrainerExport, MedCATTrainerExportProject +from medcat.stats.mctexport import MedCATTrainerExportDocument, MedCATTrainerExportAnnotation +from medcat.stats.mctexport import count_all_annotations, count_all_docs, get_nr_of_annotations +from medcat.stats.mctexport import iter_anns, iter_docs, MedCATTrainerExportProjectInfo + + + +class CDBLike(Protocol): + pass + + +class CATLike(Protocol): + + @property + def cdb(self) -> CDBLike: + pass + + def train_supervised_raw(self, + data: Dict[str, List[Dict[str, dict]]], + reset_cui_count: bool = False, + nepochs: int = 1, + print_stats: int = 0, + use_filters: bool = False, + terminate_last: bool = False, + use_overlaps: bool = False, + use_cui_doc_limit: bool = False, + test_size: float = 0, + devalue_others: bool = False, + use_groups: bool = False, + never_terminate: bool = False, + train_from_false_positives: bool = False, + extra_cui_filter: Optional[Set] = None, + retain_extra_cui_filter: bool = False, + checkpoint: Optional[Checkpoint] = None, + retain_filters: bool = False, + is_resumed: bool = False) -> Tuple: + pass + + +class SplitType(Enum): + """The split type.""" + DOCUMENTS = auto() + """Split over number of documents.""" + ANNOTATIONS = auto() + """Split over number of annotations.""" + DOCUMENTS_WEIGHTED = auto() + """Split over number of documents based on the number of annotations. + So essentially this ensures that the same document isn't in 2 folds + while trying to more equally distribute documents with different number + of annotations. + For example: + If we have 6 documents that we want to split into 3 folds. 
+ The numbers of annotations per document are as follows: + [40, 40, 20, 10, 5, 5] + If we were to split this trivially over documents, we'd end up + with 3 folds whose annotation counts are far from even: + [80, 30, 10] + However, if we use the annotations as weights, we would be able to + create folds that have more evenly distributed annotations, e.g.: + [[D1,], [D2], [D3, D4, D5, D6]] + where D# denotes the #-th document, with the number of + annotations per fold being equal: + [ 40, 40, 20 + 10 + 5 + 5 = 40] + """ + + +class FoldCreator(ABC): + """The FoldCreator based on an MCT export. + + Args: + mct_export (MedCATTrainerExport): The MCT export dict. + nr_of_folds (int): Number of folds to create. + """ + + def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int) -> None: + self.mct_export = mct_export + self.nr_of_folds = nr_of_folds + + def _find_or_add_doc(self, project: MedCATTrainerExportProject, orig_doc: MedCATTrainerExportDocument + ) -> MedCATTrainerExportDocument: + for existing_doc in project['documents']: + if existing_doc['name'] == orig_doc['name']: + return existing_doc + new_doc: MedCATTrainerExportDocument = deepcopy(orig_doc) + new_doc['annotations'].clear() + project['documents'].append(new_doc) + return new_doc + + def _create_new_project(self, proj_info: MedCATTrainerExportProjectInfo) -> MedCATTrainerExportProject: + (proj_name, proj_id, proj_cuis, proj_tuis) = proj_info + cur_project = cast(MedCATTrainerExportProject, { + 'name': proj_name, + 'id': proj_id, + 'cuis': proj_cuis, + 'documents': [], + }) + # NOTE: Some MCT exports don't declare TUIs + if proj_tuis is not None: + cur_project['tuis'] = proj_tuis + return cur_project + + def _create_export_with_documents(self, relevant_docs: Iterable[Tuple[MedCATTrainerExportProjectInfo, + MedCATTrainerExportDocument]]) -> MedCATTrainerExport: + export: MedCATTrainerExport = { + "projects": [] + } + # helper for finding projects per name + used_projects: Dict[str, MedCATTrainerExportProject] = {} + for proj_info, doc in relevant_docs: + proj_name = proj_info[0] + if proj_name not in used_projects: + cur_project = self._create_new_project(proj_info) # TODO - make sure it's available + export['projects'].append(cur_project) + used_projects[proj_name] = cur_project + else: + cur_project = used_projects[proj_name] + cur_project['documents'].append(doc) + return export + + + @abstractmethod + def create_folds(self) -> List[MedCATTrainerExport]: + """Create folds. + + Raises: + ValueError: If something went wrong. + + Returns: + List[MedCATTrainerExport]: The created folds. + """ + + +class SimpleFoldCreator(FoldCreator): + + def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int, + counter: Callable[[MedCATTrainerExport], int]) -> None: + super().__init__(mct_export, nr_of_folds) + self._counter = counter + self.total = self._counter(mct_export) + self.per_fold = self._init_per_fold() + + def _init_per_fold(self) -> List[int]: + per_fold = [self.total // self.nr_of_folds for _ in range(self.nr_of_folds)] + total = sum(per_fold) + if total < self.total: + per_fold[-1] += self.total - total + if any(pf <= 0 for pf in per_fold): + raise ValueError(f"Failed to calculate per-fold items.
Got: {per_fold}") + return per_fold + + @abstractmethod + def _create_fold(self, fold_nr: int) -> MedCATTrainerExport: + pass + + def create_folds(self) -> List[MedCATTrainerExport]: + return [ + self._create_fold(fold_nr) for fold_nr in range(self.nr_of_folds) + ] + + + +class PerDocsFoldCreator(FoldCreator): + + def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int) -> None: + super().__init__(mct_export, nr_of_folds) + self.nr_of_docs = count_all_docs(self.mct_export) + self.per_doc_simple = self.nr_of_docs // self.nr_of_folds + self._all_docs = list(iter_docs(self.mct_export)) + + def _create_fold(self, fold_nr: int) -> MedCATTrainerExport: + start_nr = self.per_doc_simple * fold_nr + # until the end for last fold, otherwise just the next set of docs + end_nr = self.nr_of_docs if fold_nr == self.nr_of_folds - 1 else start_nr + self.per_doc_simple + relevant_docs = self._all_docs[start_nr: end_nr] + return self._create_export_with_documents(relevant_docs) + + def create_folds(self) -> List[MedCATTrainerExport]: + return [ + self._create_fold(fold_nr) for fold_nr in range(self.nr_of_folds) + ] + + +class PerAnnsFoldCreator(SimpleFoldCreator): + + def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int) -> None: + super().__init__(mct_export, nr_of_folds, count_all_annotations) + + def _add_target_ann(self, project: MedCATTrainerExportProject, + orig_doc: MedCATTrainerExportDocument, + ann: MedCATTrainerExportAnnotation) -> None: + cur_doc: MedCATTrainerExportDocument = self._find_or_add_doc(project, orig_doc) + cur_doc['annotations'].append(ann) + + def _targets(self) -> Iterable[Tuple[MedCATTrainerExportProjectInfo, + MedCATTrainerExportDocument, + MedCATTrainerExportAnnotation]]: + return iter_anns(self.mct_export) + + def _create_fold(self, fold_nr: int) -> MedCATTrainerExport: + per_fold = self.per_fold[fold_nr] + cur_fold: MedCATTrainerExport = { + 'projects': [] + } + cur_project: Optional[MedCATTrainerExportProject] = None + included = 0 + for target in self._targets(): + proj_info, cur_doc, cur_ann = target + proj_name = proj_info[0] + if not cur_project or cur_project['name'] != proj_name: + # first or new project + cur_project = self._create_new_project(proj_info) + cur_fold['projects'].append(cur_project) + self._add_target_ann(cur_project, cur_doc, cur_ann) + included += 1 + if included == per_fold: + break + if included > per_fold: + raise ValueError("Got a larger fold than expected. 
" + f"Expected {per_fold}, got {included}") + return cur_fold + + +class WeightedDocumentsCreator(FoldCreator): + + def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int, + weight_calculator: Callable[[MedCATTrainerExportDocument], int]) -> None: + super().__init__(mct_export, nr_of_folds) + self._weight_calculator = weight_calculator + docs = [(doc, self._weight_calculator(doc[1])) for doc in iter_docs(self.mct_export)] + # descending order in weight + self._weighted_docs = sorted(docs, key=lambda d: d[1], reverse=True) + + def create_folds(self) -> List[MedCATTrainerExport]: + doc_folds: List[List[Tuple[MedCATTrainerExportProjectInfo, MedCATTrainerExportDocument]]] + doc_folds = [[] for _ in range(self.nr_of_folds)] + fold_weights = [0] * self.nr_of_folds + + for item, weight in self._weighted_docs: + # Find the subset with the minimum total weight + min_subset_idx = np.argmin(fold_weights) + # add the most heavily weighted document + doc_folds[min_subset_idx].append(item) + fold_weights[min_subset_idx] += weight + + return [self._create_export_with_documents(docs) for docs in doc_folds] + + +def get_fold_creator(mct_export: MedCATTrainerExport, + nr_of_folds: int, + split_type: SplitType) -> FoldCreator: + """Get the appropriate fold creator. + + Args: + mct_export (MedCATTrainerExport): The MCT export. + nr_of_folds (int): Number of folds to use. + split_type (SplitType): The type of split to use. + + Raises: + ValueError: In case of an unknown split type. + + Returns: + FoldCreator: The corresponding fold creator. + """ + if split_type is SplitType.DOCUMENTS: + return PerDocsFoldCreator(mct_export=mct_export, nr_of_folds=nr_of_folds) + elif split_type is SplitType.ANNOTATIONS: + return PerAnnsFoldCreator(mct_export=mct_export, nr_of_folds=nr_of_folds) + elif split_type is SplitType.DOCUMENTS_WEIGHTED: + return WeightedDocumentsCreator(mct_export=mct_export, nr_of_folds=nr_of_folds, + weight_calculator=get_nr_of_annotations) + else: + raise ValueError(f"Unknown Split Type: {split_type}") + + +def get_per_fold_metrics(cat: CATLike, folds: List[MedCATTrainerExport], + *args, **kwargs) -> List[Tuple]: + metrics = [] + for fold_nr, cur_fold in enumerate(folds): + others = list(folds) + others.pop(fold_nr) + with captured_state_cdb(cat.cdb): + for other in others: + cat.train_supervised_raw(cast(Dict[str, Any], other), *args, **kwargs) + stats = get_stats(cat, cast(Dict[str, Any], cur_fold), do_print=False) + metrics.append(stats) + return metrics + + +def _update_all_weighted_average(joined: List[Dict[str, Tuple[int, float]]], + single: List[Dict[str, float]], cui2count: Dict[str, int]) -> None: + if len(joined) != len(single): + raise ValueError(f"Incompatible lists. 
Joined {len(joined)} and single {len(single)}") + for j, s in zip(joined, single): + _update_one_weighted_average(j, s, cui2count) + + +def _update_one_weighted_average(joined: Dict[str, Tuple[int, float]], + one: Dict[str, float], + cui2count: Dict[str, int]) -> None: + for k in one: + if k not in joined: + joined[k] = (0, 0) + prev_w, prev_val = joined[k] + new_w, new_val = cui2count[k], one[k] + total_w = prev_w + new_w + total_val = (prev_w * prev_val + new_w * new_val) / total_w + joined[k] = (total_w, total_val) + + +def _update_all_add(joined: List[Dict[str, int]], single: List[Dict[str, int]]) -> None: + if len(joined) != len(single): + raise ValueError(f"Incompatible number of items: {len(joined)} vs {len(single)}") + for j, s in zip(joined, single): + for k, v in s.items(): + j[k] = j.get(k, 0) + v + + +def _merge_examples(all_examples: Dict, cur_examples: Dict) -> None: + for ex_type, ex_dict in cur_examples.items(): + if ex_type not in all_examples: + all_examples[ex_type] = {} + per_type_examples = all_examples[ex_type] + for ex_cui, cui_examples_list in ex_dict.items(): + if ex_cui not in per_type_examples: + per_type_examples[ex_cui] = [] + per_type_examples[ex_cui].extend(cui_examples_list) + + +def get_metrics_mean(metrics: List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]] + ) -> Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]: + """Take the mean of the provided metrics. + + Args: + metrics (List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]]): The metrics. + + Returns: + fps (dict): + False positives for each CUI. + fns (dict): + False negatives for each CUI. + tps (dict): + True positives for each CUI. + cui_prec (dict): + Precision for each CUI. + cui_rec (dict): + Recall for each CUI. + cui_f1 (dict): + F1 for each CUI. + cui_counts (dict): + Number of occurrences for each CUI. + examples (dict): + Examples for each of the fp, fn, tp. Format will be examples['fp']['cui'][]. + """ + # additives + all_fps: Dict[str, int] = {} + all_fns: Dict[str, int] = {} + all_tps: Dict[str, int] = {} + # weighted-averages + all_cui_prec: Dict[str, Tuple[int, float]] = {} + all_cui_rec: Dict[str, Tuple[int, float]] = {} + all_cui_f1: Dict[str, Tuple[int, float]] = {} + # additive + all_cui_counts: Dict[str, int] = {} + # combined + all_additives = [ + all_fps, all_fns, all_tps, all_cui_counts + ] + all_weighted_averages = [ + all_cui_prec, all_cui_rec, all_cui_f1 + ] + # examples + all_examples: dict = {} + for current in metrics: + cur_wa: list = list(current[3:-2]) + cur_counts = current[-2] + _update_all_weighted_average(all_weighted_averages, cur_wa, cur_counts) + # update ones that just need to be added up + cur_adds = list(current[:3]) + [cur_counts] + _update_all_add(all_additives, cur_adds) + # merge examples + cur_examples = current[-1] + _merge_examples(all_examples, cur_examples) + cui_prec: Dict[str, float] = {} + cui_rec: Dict[str, float] = {} + cui_f1: Dict[str, float] = {} + final_wa = [ + cui_prec, cui_rec, cui_f1 + ] + # just remove the weight / count + for df, d in zip(final_wa, all_weighted_averages): + for k, v in d.items(): + df[k] = v[1] # only the value, ignore the weight + return (all_fps, all_fns, all_tps, final_wa[0], final_wa[1], final_wa[2], + all_cui_counts, all_examples) + + +def get_k_fold_stats(cat: CATLike, mct_export_data: MedCATTrainerExport, k: int = 3, + split_type: SplitType = SplitType.DOCUMENTS_WEIGHTED, *args, **kwargs) -> Tuple: + """Get the k-fold stats for the model with the specified data.
+ + First this will split the MCT export into `k` folds. You can do + this either per document or per annotation. + + For each of the `k` folds, it will start from the base model, + train it with the other `k-1` folds and record the metrics. + After that the base model state is restored before doing the next fold. + After all the folds have been done, the metrics are averaged. + + Args: + cat (CATLike): The model pack. + mct_export_data (MedCATTrainerExport): The MCT export. + k (int): The number of folds. Defaults to 3. + split_type (SplitType): Whether to use annotations or docs. Defaults to DOCUMENTS_WEIGHTED. + *args: Arguments passed to the `CAT.train_supervised_raw` method. + **kwargs: Keyword arguments passed to the `CAT.train_supervised_raw` method. + + Returns: + Tuple: The averaged metrics. + """ + creator = get_fold_creator(mct_export_data, k, split_type=split_type) + folds = creator.create_folds() + per_fold_metrics = get_per_fold_metrics(cat, folds, *args, **kwargs) + return get_metrics_mean(per_fold_metrics) diff --git a/medcat/stats/mctexport.py b/medcat/stats/mctexport.py new file mode 100644 index 000000000..54f5a4443 --- /dev/null +++ b/medcat/stats/mctexport.py @@ -0,0 +1,66 @@ +from typing import List, Iterator, Tuple, Any, Optional +from typing_extensions import TypedDict + + +class MedCATTrainerExportAnnotation(TypedDict): + start: int + end: int + cui: str + value: str + + +class MedCATTrainerExportDocument(TypedDict): + name: str + id: Any + last_modified: str + text: str + annotations: List[MedCATTrainerExportAnnotation] + + +class MedCATTrainerExportProject(TypedDict): + name: str + id: Any + cuis: str + tuis: Optional[str] + documents: List[MedCATTrainerExportDocument] + + +MedCATTrainerExportProjectInfo = Tuple[str, Any, str, Optional[str]] +"""The project name, project ID, CUIs str, and TUIs str""" + + +class MedCATTrainerExport(TypedDict): + projects: List[MedCATTrainerExportProject] + + +def iter_projects(export: MedCATTrainerExport) -> Iterator[MedCATTrainerExportProject]: + yield from export['projects'] + + +def iter_docs(export: MedCATTrainerExport + ) -> Iterator[Tuple[MedCATTrainerExportProjectInfo, MedCATTrainerExportDocument]]: + for project in iter_projects(export): + info: MedCATTrainerExportProjectInfo = ( + project['name'], project['id'], project['cuis'], project.get('tuis', None) + ) + for doc in project['documents']: + yield info, doc + + +def iter_anns(export: MedCATTrainerExport + ) -> Iterator[Tuple[MedCATTrainerExportProjectInfo, MedCATTrainerExportDocument, MedCATTrainerExportAnnotation]]: + for proj_info, doc in iter_docs(export): + for ann in doc['annotations']: + yield proj_info, doc, ann + + +def count_all_annotations(export: MedCATTrainerExport) -> int: + return len(list(iter_anns(export))) + + +def count_all_docs(export: MedCATTrainerExport) -> int: + return len(list(iter_docs(export))) + + +def get_nr_of_annotations(doc: MedCATTrainerExportDocument) -> int: + return len(doc['annotations']) diff --git a/medcat/stats/stats.py b/medcat/stats/stats.py index 610d4d2a1..e467e0519 100644 --- a/medcat/stats/stats.py +++ b/medcat/stats/stats.py @@ -60,6 +60,9 @@ def process_project(self, project: dict) -> None: # Add extra filter if set set_project_filters(self.addl_info, self.filters, project, self.extra_cui_filter, self.use_project_filters) + project_name = cast(str, project.get('name')) + project_id = cast(str, project.get('id')) + documents = project["documents"] for dind, doc in tqdm( enumerate(documents), @@ -67,8 +70,7 @@ def
process_project(self, project: dict) -> None: total=len(documents), leave=False, ): - self.process_document(cast(str, project.get('name')), - cast(str, project.get('id')), doc) + self.process_document(project_name, project_id, doc) def process_document(self, project_name: str, project_id: str, doc: dict) -> None: anns = self._get_doc_annotations(doc) diff --git a/medcat/tokenizers/meta_cat_tokenizers.py b/medcat/tokenizers/meta_cat_tokenizers.py index 7a4b07ac0..93d8b51ed 100644 --- a/medcat/tokenizers/meta_cat_tokenizers.py +++ b/medcat/tokenizers/meta_cat_tokenizers.py @@ -1,3 +1,4 @@ +import logging import os from abc import ABC, abstractmethod from typing import List, Dict, Optional, Union, overload @@ -26,7 +27,7 @@ def save(self, dir_path: str) -> None: ... @classmethod @abstractmethod - def load(cls, dir_path: str, **kwargs) -> Tokenizer: ... + def load(cls, dir_path: str, model_variant: Optional[str] = '', **kwargs) -> Tokenizer: ... @abstractmethod def get_size(self) -> int: ... @@ -112,7 +113,7 @@ def save(self, dir_path: str) -> None: self.hf_tokenizers.save_model(dir_path, prefix=self.name) @classmethod - def load(cls, dir_path: str, **kwargs) -> "TokenizerWrapperBPE": + def load(cls, dir_path: str, model_variant: Optional[str] = '', **kwargs) -> "TokenizerWrapperBPE": tokenizer = cls() vocab_file = os.path.join(dir_path, f'{tokenizer.name}-vocab.json') merges_file = os.path.join(dir_path, f'{tokenizer.name}-merges.txt') @@ -186,10 +187,14 @@ def save(self, dir_path: str) -> None: self.hf_tokenizers.save_pretrained(path) @classmethod - def load(cls, dir_path: str, **kwargs) -> "TokenizerWrapperBERT": + def load(cls, dir_path: str, model_variant: Optional[str] = '', **kwargs) -> "TokenizerWrapperBERT": tokenizer = cls() path = os.path.join(dir_path, cls.name) - tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(path, **kwargs) + try: + tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(path, **kwargs) + except Exception as e: + logging.warning("Could not load tokenizer from path due to error: {}. Loading from library for model variant: {}".format(e,model_variant)) + tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(model_variant) return tokenizer diff --git a/medcat/utils/cdb_state.py b/medcat/utils/cdb_state.py new file mode 100644 index 000000000..794a40109 --- /dev/null +++ b/medcat/utils/cdb_state.py @@ -0,0 +1,179 @@ +import logging +import contextlib +from typing import Dict, TypedDict, Set, List, cast +import numpy as np +import tempfile +import dill + +from copy import deepcopy + + + +logger = logging.getLogger(__name__) # separate logger from the package-level one + + +CDBState = TypedDict( + 'CDBState', + { + 'name2cuis': Dict[str, List[str]], + 'snames': Set[str], + 'cui2names': Dict[str, Set[str]], + 'cui2snames': Dict[str, Set[str]], + 'cui2context_vectors': Dict[str, Dict[str, np.ndarray]], + 'cui2count_train': Dict[str, int], + 'name_isupper': Dict, + 'vocab': Dict[str, int], + }) +"""CDB State. + +This is a dictionary of the parts of the CDB that change during +(supervised) training. It can be used to store and restore the +state of a CDB after modifying it. + +Currently, the following fields are saved: + - name2cuis + - snames + - cui2names + - cui2snames + - cui2context_vectors + - cui2count_train + - name_isupper + - vocab +""" + + +def copy_cdb_state(cdb) -> CDBState: + """Creates a (deep) copy of the CDB state. + + Grabs the fields that correspond to the state, + creates deep copies, and returns the copies. 
+ + Args: + cdb: The CDB from which to grab the state. + + Returns: + CDBState: The copied state. + """ + return cast(CDBState, { + k: deepcopy(getattr(cdb, k)) for k in CDBState.__annotations__ + }) + + +def save_cdb_state(cdb, file_path: str) -> None: + """Saves CDB state in a file. + + Currently uses `dill.dump` to save the relevant fields/values. + + Args: + cdb: The CDB from which to grab the state. + file_path (str): The file to dump the state. + """ + # NOTE: The difference is that we don't create a copy here. + # That is so that we don't have to occupy the memory for + # both copies + the_dict = { + k: getattr(cdb, k) for k in CDBState.__annotations__ + } + logger.debug("Saving CDB state on disk at: '%s'", file_path) + with open(file_path, 'wb') as f: + dill.dump(the_dict, f) + + +def apply_cdb_state(cdb, state: CDBState) -> None: + """Apply the specified state to the specified CDB. + + This overwrites the current state of the CDB with one provided. + + Args: + cdb: The CDB to apply the state to. + state (CDBState): The state to use. + """ + for k, v in state.items(): + setattr(cdb, k, v) + + +def load_and_apply_cdb_state(cdb, file_path: str) -> None: + """Delete current CDB state and apply CDB state from file. + + This first deletes the current state of the CDB. + This is to save memory. The idea is that saving the state + on disk will save on RAM usage. But it wouldn't really + work too well if, upon load, two instances were still in + memory. + + Args: + cdb: The CDB to apply the state to. + file_path (str): The file where the state has been saved to. + """ + # clear existing data on CDB + # this is so that we don't occupy the memory for both the loaded + # and the on-CDB data + logger.debug("Clearing CDB state in memory") + for k in CDBState.__annotations__: + val = getattr(cdb, k) + setattr(cdb, k, None) + del val + logger.debug("Loading CDB state from disk from '%s'", file_path) + with open(file_path, 'rb') as f: + data = dill.load(f) + for k in CDBState.__annotations__: + setattr(cdb, k, data[k]) + + +@contextlib.contextmanager +def captured_state_cdb(cdb, save_state_to_disk: bool = False): + """A context manager that captures and re-applies the initial CDB state. + + The context manager captures/copies the initial state of the CDB when entering. + It then allows the user to modify the state (i.e. training). + Upon exit, it re-applies the initial CDB state. + + If RAM is an issue, it is recommended to use `save_state_to_disk`. + Otherwise the copy of the original state will be held in memory. + If saved on disk, a temporary file is used and removed afterwards. + + Args: + cdb: The CDB to use. + save_state_to_disk (bool): Whether to save state on disk or hold it in memory. + Defaults to False. + + Yields: + None + """ + if save_state_to_disk: + with on_disk_memory_capture(cdb): + yield + else: + with in_memory_state_capture(cdb): + yield + + +@contextlib.contextmanager +def in_memory_state_capture(cdb): + """Capture the CDB state in memory. + + Args: + cdb: The CDB to use. + + Yields: + None + """ + state = copy_cdb_state(cdb) + yield + apply_cdb_state(cdb, state) + + +@contextlib.contextmanager +def on_disk_memory_capture(cdb): + """Capture the CDB state in a temporary file.
+ + Args: + cdb: The CDB to use. + + Yields: + None + """ + with tempfile.NamedTemporaryFile() as tf: + save_cdb_state(cdb, tf.name) + yield + load_and_apply_cdb_state(cdb, tf.name) diff --git a/medcat/utils/cdb_utils.py b/medcat/utils/cdb_utils.py index c473ddba4..fefaf1273 100644 --- a/medcat/utils/cdb_utils.py +++ b/medcat/utils/cdb_utils.py @@ -63,7 +63,7 @@ def merge_cdb(cdb1: CDB, ontologies.update(cdb2.addl_info['cui2ontologies'][cui]) if 'cui2description' in cdb2.addl_info: description = cdb2.addl_info['cui2description'][cui] - cdb.add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status, + cdb._add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status, type_ids=cdb2.cui2type_ids[cui], description=description, full_build=to_build) if cui in cdb1.cui2names: if (cui in cdb1.cui2count_train or cui in cdb2.cui2count_train) and not (overwrite_training == 1 and cui in cdb1.cui2count_train): diff --git a/medcat/utils/config_utils.py b/medcat/utils/config_utils.py index 1aafbf3f1..09989b258 100644 --- a/medcat/utils/config_utils.py +++ b/medcat/utils/config_utils.py @@ -1,15 +1,79 @@ from functools import partial -from typing import Callable +from typing import Callable, Optional, Protocol import logging +from pydantic import BaseModel + + +class WAFCarrier(Protocol): + + @property + def weighted_average_function(self) -> Callable[[int], float]: + pass logger = logging.getLogger(__name__) +def is_old_type_config_dict(d: dict) -> bool: + """Checks if the dict provided is an old style (jsonpickle) config. + + This checks for json-pickle specific keys such as py/object and py/state. + If both of those are keys somewhere within the 2 initial layers of the + nested dict, it's considered old style. + + Args: + d (dict): Loaded config. + + Returns: + bool: Whether it's an old style (jsonpickle) config. + """ + # all 2nd level keys + all_keys = set(sub_key for key in d for sub_key in (d[key] if isinstance(d[key], dict) else [key])) + # add 1st level keys + all_keys.update(d.keys()) + # is old if py/object and py/state somewhere in keys + return set(('py/object', 'py/state')) <= all_keys + + +def fix_waf_lambda(carrier: WAFCarrier) -> None: + weighted_average_function = carrier.weighted_average_function # type: ignore + if callable(weighted_average_function) and getattr(weighted_average_function, "__name__", None) == "<lambda>": + # the following type ignoring is for mypy because it is unable to detect the signature + carrier.weighted_average_function = partial(weighted_average, factor=0.0004) # type: ignore + + +# NOTE: This method is a hacky workaround.
The type ignores are because I cannot +# import config here since it would produce a circular import +def ensure_backward_compatibility(config: BaseModel, workers: Callable[[], int]) -> None: + # Hacky way of supporting old CDBs + if hasattr(config.linking, 'weighted_average_function'): # type: ignore + fix_waf_lambda(config.linking) # type: ignore + if config.general.workers is None: # type: ignore + config.general.workers = workers() # type: ignore + disabled_comps = config.general.spacy_disabled_components # type: ignore + if 'tagger' in disabled_comps and 'lemmatizer' not in disabled_comps: + config.general.spacy_disabled_components.append('lemmatizer') # type: ignore + + +def get_and_del_weighted_average_from_config(config: BaseModel) -> Optional[Callable[[int], float]]: + if not hasattr(config, 'linking'): + return None + linking = config.linking + if not hasattr(linking, 'weighted_average_function'): + return None + waf = linking.weighted_average_function + delattr(linking, 'weighted_average_function') + return waf + + def weighted_average(step: int, factor: float) -> float: return max(0.1, 1 - (step ** 2 * factor)) +def default_weighted_average(step: int) -> float: + return weighted_average(step, factor=0.0004) + + def attempt_fix_weighted_average_function(waf: Callable[[int], float] ) -> Callable[[int], float]: """Attempt to fix weighted_average_function. diff --git a/medcat/utils/decorators.py b/medcat/utils/decorators.py index a98922360..ca473774b 100644 --- a/medcat/utils/decorators.py +++ b/medcat/utils/decorators.py @@ -1,14 +1,30 @@ import warnings import functools -from typing import Callable +from typing import Callable, Tuple -def deprecated(message: str) -> Callable: +def _format_version(ver: Tuple[int, int, int]) -> str: + return ".".join(str(v) for v in ver) + + +def deprecated(message: str, depr_version: Tuple[int, int, int], removal_version: Tuple[int, int, int]) -> Callable: + """Deprecate a method. + + Args: + message (str): The deprecation message. + depr_version (Tuple[int, int, int]): The first version of MedCAT where this was deprecated. + removal_version (Tuple[int, int, int]): The first version of MedCAT where this will be removed. + + Returns: + Callable: The decorator. + """ def decorator(func: Callable) -> Callable: @functools.wraps(func) def wrapped(*args, **kwargs) -> Callable: warnings.simplefilter("always", DeprecationWarning) warnings.warn("Function {} has been deprecated.{}".format(func.__name__, " " + message if message else "")) + warnings.warn(f"The above function was deprecated in v{_format_version(depr_version)} " + f"and will be removed in v{_format_version(removal_version)}") warnings.simplefilter("default", DeprecationWarning) return func(*args, **kwargs) return wrapped diff --git a/medcat/utils/meta_cat/data_utils.py b/medcat/utils/meta_cat/data_utils.py index 5d2060ca7..c4dc5f9c2 100644 --- a/medcat/utils/meta_cat/data_utils.py +++ b/medcat/utils/meta_cat/data_utils.py @@ -1,5 +1,8 @@ from typing import Dict, Optional, Tuple, Iterable, List from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase +import logging + +logger = logging.getLogger(__name__) def prepare_from_json(data: Dict, @@ -23,8 +26,6 @@ def prepare_from_json(data: Dict, Size of context to get from the right of the concept tokenizer (TokenizerWrapperBase): Something to split text into tokens for the LSTM/BERT/whatever meta models. - cui_filter (Optional[set]): - CUI filter if set. Defaults to None.
replace_center (Optional[str]): If not None the center word (concept) will be replaced with whatever this is. prerequisites (Dict): @@ -33,6 +34,8 @@ def prepare_from_json(data: Dict, {'Experiencer': 'Patient'} - Take care that the CASE has to match whatever is in the data. Defaults to `{}`. lowercase (bool): Should the text be lowercased before tokenization. Defaults to True. + cui_filter (Optional[set]): + CUI filter if set. Defaults to None. Returns: out_data (dict): @@ -49,7 +52,8 @@ def prepare_from_json(data: Dict, if len(text) > 0: doc_text = tokenizer(text) - for ann in document.get('annotations', document.get('entities', {}).values()): # A hack to suport entities and annotations + for ann in document.get('annotations', document.get('entities', + {}).values()): # A hack to support entities and annotations cui = ann['cui'] skip = False if 'meta_anns' in ann and prerequisites: @@ -61,21 +65,28 @@ def prepare_from_json(data: Dict, break if not skip and (cui_filter is None or not cui_filter or cui in cui_filter): - if ann.get('validated', True) and (not ann.get('deleted', False) and not ann.get('killed', False) - and not ann.get('irrelevant', False)): + if ann.get('validated', True) and ( + not ann.get('deleted', False) and not ann.get('killed', False) + and not ann.get('irrelevant', False)): start = ann['start'] end = ann['end'] - # Get the index of the center token - ind = 0 + # Updated implementation to extract all the tokens for the medical entity (rather than just one) + ctoken_idx = [] for ind, pair in enumerate(doc_text['offset_mapping']): - if start >= pair[0] and start < pair[1]: - break - - _start = max(0, ind - cntx_left) - _end = min(len(doc_text['input_ids']), ind + 1 + cntx_right) + if start <= pair[0] or start <= pair[1]: + if end <= pair[1]: + ctoken_idx.append(ind) + break + else: + ctoken_idx.append(ind) + + _start = max(0, ctoken_idx[0] - cntx_left) + _end = min(len(doc_text['input_ids']), ctoken_idx[-1] + 1 + cntx_right) + + cpos = cntx_left + min(0, ind - cntx_left) + cpos_new = [x - _start for x in ctoken_idx] tkns = doc_text['input_ids'][_start:_end] - cpos = cntx_left + min(0, ind-cntx_left) if replace_center is not None: if lowercase: @@ -87,19 +98,19 @@ def prepare_from_json(data: Dict, e_ind = p_ind ln = e_ind - s_ind - tkns = tkns[:cpos] + tokenizer(replace_center)['input_ids'] + tkns[cpos+ln+1:] + tkns = tkns[:cpos] + tokenizer(replace_center)['input_ids'] + tkns[cpos + ln + 1:] # Backward compatibility if meta_anns is a list vs dict in the new approach meta_anns = [] if 'meta_anns' in ann: - meta_anns = ann['meta_anns'].values() if type(ann['meta_anns']) is dict else ann['meta_anns'] + meta_anns = ann['meta_anns'].values() if isinstance(ann['meta_anns'], dict) else ann['meta_anns'] # If the annotation is validated for meta_ann in meta_anns: name = meta_ann['name'] value = meta_ann['value'] - sample = [tkns, cpos, value] + sample = [tkns, cpos_new, value] if name in out_data: out_data[name].append(sample) @@ -108,7 +119,41 @@ def prepare_from_json(data: Dict, return out_data -def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict] = None) -> Tuple: +def prepare_for_oversampled_data(data: List, + tokenizer: TokenizerWrapperBase) -> List: + """Convert the data from a json format into a CSV-like format for training. This function is not very efficient (the one + working with spacy documents as part of the meta_cat.pipe method is much better).
If your dataset is > 1M documents think + about rewriting this function - but it would be strange to have more than 1M manually annotated documents. + + Args: + data (List): + Oversampled data expected in the following format: + [[['text','of','the','document'], [index of medical entity], "label"], + [['text','of','the','document'], [index of medical entity], "label"]] + tokenizer (TokenizerWrapperBase): + Something to split text into tokens for the LSTM/BERT/whatever meta models. + + Returns: + data_sampled (list): + The processed data in the format that can be merged with the output from prepare_from_json. + [[<[tokens]>, [index of medical entity], "label"], + [<[tokens]>, [index of medical entity], "label"]] + """ + + data_sampled = [] + for sample in data: + # Checking if the input is already tokenized + if isinstance(sample[0][0], str): + doc_text = tokenizer(sample[0]) + data_sampled.append([doc_text[0]['input_ids'], sample[1], sample[2]]) + else: + data_sampled.append([sample[0], sample[1], sample[2]]) + + return data_sampled + + +def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict] = None, + category_undersample=None) -> Tuple: """Converts the category values in the data output by `prepare_from_json` into integer values. @@ -117,10 +162,14 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict Output of `prepare_from_json`. existing_category_value2id(Optional[Dict]): Map from category_value to id (old/existing). + category_undersample: + Name of class that should be used to undersample the data (for 2 phase learning) Returns: dict: - New data with integeres inplace of strings for categry values. + New undersampled data (for 2 phase learning) with integers in place of strings for category values + dict: + New data with integers in place of strings for category values. dict: Map from category value to ID for all categories in the data.
""" @@ -131,6 +180,23 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict category_value2id = {} category_values = set([x[2] for x in data]) + # Ensuring that each label has data and checking for class imbalance + + label_data = {key: 0 for key in category_value2id} + for i in range(len(data)): + if data[i][2] in category_value2id: + label_data[data[i][2]] = label_data[data[i][2]] + 1 + + # If a label has no data, changing the mapping + if 0 in label_data.values(): + category_value2id_: Dict = {} + keys_ls = [key for key, value in category_value2id.items() if value != 0] + for k in keys_ls: + category_value2id_[k] = len(category_value2id_) + + logger.warning("Labels found with 0 data; updates made\nFinal label encoding mapping:", category_value2id_) + category_value2id = category_value2id_ + for c in category_values: if c not in category_value2id: category_value2id[c] = len(category_value2id) @@ -139,30 +205,39 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict for i in range(len(data)): data[i][2] = category_value2id[data[i][2]] - return data, category_value2id + # Creating dict with labels and its number of samples + label_data_ = {v: 0 for v in category_value2id.values()} + for i in range(len(data)): + if data[i][2] in category_value2id.values(): + label_data_[data[i][2]] = label_data_[data[i][2]] + 1 + # Undersampling data + if category_undersample is None or category_undersample == '': + min_label = min(label_data_.values()) + else: + if category_undersample not in label_data_.keys() and category_undersample in category_value2id.keys(): + min_label = label_data_[category_value2id[category_undersample]] + else: + min_label = label_data_[category_undersample] -class Span(object): - def __init__(self, start_char: str, end_char: str, id_: str) -> None: - self._ = Empty() - self.start_char = start_char - self.end_char = end_char - self._.id = id_ # type: ignore - self._.meta_anns = None # type: ignore + data_undersampled = [] + label_data_counter = {v: 0 for v in category_value2id.values()} + for sample in data: + if label_data_counter[sample[-1]] < min_label: + data_undersampled.append(sample) + label_data_counter[sample[-1]] += 1 -class Doc(object): - def __init__(self, text: str, id_: str) -> None: - self._ = Empty() - self._.share_tokens = None # type: ignore - self.ents: List = [] - # We do not have overlapps at this stage - self._ents = self.ents - self.text = text - self.id = id_ + label_data = {v: 0 for v in category_value2id.values()} + for i in range(len(data_undersampled)): + if data_undersampled[i][2] in category_value2id.values(): + label_data[data_undersampled[i][2]] = label_data[data_undersampled[i][2]] + 1 + logger.info(f"Updated label_data: {label_data}") + + return data_undersampled, data, category_value2id -def json_to_fake_spacy(data: Dict, id2text: Dict) -> Iterable[Doc]: +def json_to_fake_spacy(data: Dict, id2text: Dict) -> Iterable: """Creates a generator of fake spacy documents, used for running meta_cat pipe separately from main cat pipeline. @@ -173,7 +248,7 @@ def json_to_fake_spacy(data: Dict, id2text: Dict) -> Iterable[Doc]: Map from document id to text of that document. Yields: - Doc: spacy like documents that can be feed into meta_cat.pipe. + Generator: Generator of spacy like documents that can be feed into meta_cat.pipe. 
""" for id_ in data.keys(): ents = data[id_]['entities'].values() @@ -187,3 +262,23 @@ def json_to_fake_spacy(data: Dict, id2text: Dict) -> Iterable[Doc]: class Empty(object): def __init__(self) -> None: pass + + +class Span(object): + def __init__(self, start_char: str, end_char: str, id_: str) -> None: + self._ = Empty() + self.start_char = start_char + self.end_char = end_char + self._.id = id_ # type: ignore + self._.meta_anns = None # type: ignore + + +class Doc(object): + def __init__(self, text: str, id_: str) -> None: + self._ = Empty() + self._.share_tokens = None # type: ignore + self.ents: List = [] + # We do not have overlapps at this stage + self._ents = self.ents + self.text = text + self.id = id_ diff --git a/medcat/utils/meta_cat/ml_utils.py b/medcat/utils/meta_cat/ml_utils.py index 9f75fa69a..79cedb9f3 100644 --- a/medcat/utils/meta_cat/ml_utils.py +++ b/medcat/utils/meta_cat/ml_utils.py @@ -2,18 +2,22 @@ import random import math import torch +import torch.nn.functional as F import numpy as np import pandas as pd import torch.optim as optim -from typing import List, Optional, Tuple, Any, Dict +from typing import List, Optional, Tuple, Any, Dict, Union from torch import nn from scipy.special import softmax from medcat.config_meta_cat import ConfigMetaCAT from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase from sklearn.metrics import classification_report, precision_recall_fscore_support, confusion_matrix +from sklearn.model_selection import train_test_split +from sklearn.utils.class_weight import compute_class_weight +from transformers import AdamW, get_linear_schedule_with_warmup -import logging +import logging logger = logging.getLogger(__name__) @@ -47,9 +51,13 @@ def create_batch_piped_data(data: List[Tuple[List[int], int, Optional[int]]], Same as data, but subsetted and as a tensor cpos (): Center positions for the data + attention_mask: + Indicating padding mask for the data + y: + class label of the data """ max_seq_len = max([len(x[0]) for x in data]) - x = [x[0][0:max_seq_len] + [pad_id]*max(0, max_seq_len - len(x[0])) for x in data[start_ind:end_ind]] + x = [x[0][0:max_seq_len] + [pad_id] * max(0, max_seq_len - len(x[0])) for x in data[start_ind:end_ind]] cpos = [x[1] for x in data[start_ind:end_ind]] y = None if len(data[0]) == 3: @@ -57,9 +65,9 @@ def create_batch_piped_data(data: List[Tuple[List[int], int, Optional[int]]], y = torch.tensor([x[2] for x in data[start_ind:end_ind]], dtype=torch.long).to(device) x = torch.tensor(x, dtype=torch.long).to(device) - cpos = torch.tensor(cpos, dtype=torch.long).to(device) - - return x, cpos, y + # cpos = torch.tensor(cpos, dtype=torch.long).to(device) + attention_masks = (x != 0).type(torch.int) + return x, cpos, attention_masks, y def predict(model: nn.Module, data: List[Tuple[List[int], int, Optional[int]]], @@ -94,8 +102,10 @@ def predict(model: nn.Module, data: List[Tuple[List[int], int, Optional[int]]], with torch.no_grad(): for i in range(num_batches): - x, cpos, _ = create_batch_piped_data(data, i*batch_size, (i+1)*batch_size, device=device, pad_id=pad_id) - logits = model(x, cpos, ignore_cpos=ignore_cpos) + x, cpos, attention_masks, _ = create_batch_piped_data(data, i * batch_size, (i + 1) * batch_size, + device=device, pad_id=pad_id) + + logits = model(x, center_positions=cpos, attention_mask=attention_masks, ignore_cpos=ignore_cpos) all_logits.append(logits.detach().cpu().numpy()) predictions = [] @@ -111,7 +121,7 @@ def predict(model: nn.Module, data: List[Tuple[List[int], int, 
Optional[int]]], def split_list_train_test(data: List, test_size: float, shuffle: bool = True) -> Tuple: - """Shuffle and randomply split data + """Shuffle and randomly split data Args: data (List): The data. @@ -124,9 +134,14 @@ def split_list_train_test(data: List, test_size: float, shuffle: bool = True) -> if shuffle: random.shuffle(data) - test_ind = int(len(data) * test_size) - test_data = data[:test_ind] - train_data = data[test_ind:] + X_features = [x[:-1] for x in data] + y_labels = [x[-1] for x in data] + + X_train, X_test, y_train, y_test = train_test_split(X_features, y_labels, test_size=test_size, + random_state=42) + + train_data = [x + [y] for x, y in zip(X_train, y_train)] + test_data = [x + [y] for x, y in zip(X_test, y_test)] return train_data, test_data @@ -142,12 +157,25 @@ def print_report(epoch: int, running_loss: List, all_logits: List, y: Any, name: name (str): The name of the report. Defaults to Train. """ if all_logits: - logger.info('Epoch: %d %s %s', epoch, "*"*50, name) + logger.info('Epoch: %d %s %s', epoch, "*" * 50, name) logger.info(classification_report(y, np.argmax(np.concatenate(all_logits, axis=0), axis=1))) +class FocalLoss(nn.Module): + def __init__(self, alpha=None, gamma=2): + super(FocalLoss, self).__init__() + self.alpha = alpha + self.gamma = gamma + + def forward(self, inputs, targets): + ce_loss = F.cross_entropy(inputs, targets, reduction='none') + pt = torch.exp(-ce_loss) + # alpha may be None when focal loss is used without class weights + alpha_t = self.alpha[targets] if self.alpha is not None else 1.0 + loss = (alpha_t * (1 - pt) ** self.gamma * ce_loss).mean() + return loss + + def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_path: Optional[str] = None) -> Dict: - """Trains a LSTM model (for now) with autocheckpoints + """Trains an LSTM or BERT model with autocheckpoints Args: model (nn.Module): The model @@ -162,18 +190,82 @@ def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_pa Exception: If auto-save is enabled but no save dir path is provided. """ # Get train/test from data - train_data, test_data = split_list_train_test(data, test_size=config.train['test_size'], shuffle=config.train['shuffle_data']) - device = torch.device(config.general['device']) # Create a torch device + train_data, test_data = split_list_train_test(data, test_size=config.train['test_size'], + shuffle=config.train['shuffle_data']) + device = torch.device(config.general['device']) # Create a torch device class_weights = config.train['class_weights'] - if class_weights is not None: - class_weights = torch.FloatTensor(class_weights).to(device) - criterion = nn.CrossEntropyLoss(weight=class_weights) # Set the criterion to Cross Entropy Loss + + if class_weights is None: + if config.train['compute_class_weights'] is True: + y_ = [x[2] for x in train_data] + class_weights = compute_class_weight(class_weight="balanced", classes=np.unique(y_), y=y_) + config.train['class_weights'] = class_weights + logger.info(f"Class weights computed: {class_weights}") + + class_weights = torch.FloatTensor(class_weights).to(device) + if config.train['loss_funct'] == 'cross_entropy': + criterion: Union[FocalLoss, nn.CrossEntropyLoss] = nn.CrossEntropyLoss( + weight=class_weights) + elif config.train['loss_funct'] == 'focal_loss': + criterion = FocalLoss(alpha=class_weights, gamma=config.train['gamma']) + + else: + logger.warning("Class weights not provided and compute_class_weights parameter is set to False.
No class weights used for training.") + if config.train['loss_funct'] == 'cross_entropy': + criterion = nn.CrossEntropyLoss() + elif config.train['loss_funct'] == 'focal_loss': + criterion = FocalLoss(gamma=config.train['gamma']) else: - criterion = nn.CrossEntropyLoss() # Set the criterion to Cross Entropy Loss + class_weights = torch.FloatTensor(class_weights).to(device) + if config.train['loss_funct'] == 'cross_entropy': + criterion = nn.CrossEntropyLoss( + weight=class_weights) + elif config.train['loss_funct'] == 'focal_loss': + criterion = FocalLoss(alpha=class_weights, gamma=config.train['gamma']) + parameters = filter(lambda p: p.requires_grad, model.parameters()) - optimizer = optim.Adam(parameters, lr=config.train['lr']) - model.to(device) # Move the model to device + + def initialize_model(classifier, data_, batch_size_, lr_, epochs=4): + """Initialize the Classifier, the optimizer and the learning rate scheduler. + + Args: + classifier (nn.Module): + The model to be trained + data_ (List): + The data + batch_size_: + Batch size + lr_: + Learning rate for training + epochs: + Number of training iterations + + Returns: + classifier: + model + optimizer_: + optimizer + scheduler_: + scheduler + """ + + # Create the optimizer + optimizer_ = AdamW(classifier.parameters(), + lr=lr_, # Default learning rate + eps=1e-8, # Default epsilon value + weight_decay=1e-5 + ) + + # Total number of training steps + total_steps = int((len(data_) / batch_size_) * epochs) + logger.info('Total steps for optimizer: {}'.format(total_steps)) + + # Set up the learning rate scheduler + scheduler_ = get_linear_schedule_with_warmup(optimizer_, + num_warmup_steps=0, # Default value + num_training_steps=total_steps) + return classifier, optimizer_, scheduler_ batch_size = config.train['batch_size'] batch_size_eval = config.general['batch_size_eval'] @@ -182,6 +274,13 @@ def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_pa ignore_cpos = config.model['ignore_cpos'] num_batches = math.ceil(len(train_data) / batch_size) num_batches_test = math.ceil(len(test_data) / batch_size_eval) + optimizer = optim.Adam(parameters, lr=config.train['lr'], weight_decay=1e-5) + if config.model.model_architecture_config is not None: + if config.model.model_architecture_config['lr_scheduler'] is True: + model, optimizer, scheduler = initialize_model(model, train_data, batch_size, config.train['lr'], + epochs=nepochs) + + model.to(device) # Move the model to device # Can be pre-calculated for the whole dataset y_test = [x[2] for x in test_data] @@ -193,8 +292,11 @@ def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_pa all_logits = [] model.train() for i in range(num_batches): - x, cpos, y = create_batch_piped_data(train_data, i*batch_size, (i+1)*batch_size, device=device, pad_id=pad_id) - logits = model(x, center_positions=cpos, ignore_cpos=ignore_cpos) + model.zero_grad() + + x, cpos, attention_masks, y = create_batch_piped_data(train_data, i * batch_size, (i + 1) * batch_size, + device=device, pad_id=pad_id) + logits = model(x, attention_mask=attention_masks, center_positions=cpos, ignore_cpos=ignore_cpos) loss = criterion(logits, y) loss.backward() # Track loss and logits @@ -202,8 +304,11 @@ def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_pa all_logits.append(logits.detach().cpu().numpy()) parameters = filter(lambda p: p.requires_grad, model.parameters()) - nn.utils.clip_grad_norm_(parameters, 0.25) + nn.utils.clip_grad_norm_(parameters, 0.15) 
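# A rough sketch of what the optional scheduler wired up above does (the
# numbers here are illustrative assumptions, not values from this patch):
# get_linear_schedule_with_warmup with num_warmup_steps=0 decays the learning
# rate linearly from lr to 0 over total_steps, e.g. with 3200 training
# examples, batch_size=32 and epochs=4:
#     total_steps = int((3200 / 32) * 4)   # = 400
#     lr_at_step_k = lr * max(0.0, (total_steps - k) / total_steps)
# which is why scheduler.step() is called once after every optimizer.step() below.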
optimizer.step() + if config.model.model_architecture_config is not None: + if config.model.model_architecture_config['lr_scheduler'] is True: + scheduler.step() all_logits_test = [] running_loss_test = [] @@ -211,8 +316,10 @@ def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_pa with torch.no_grad(): for i in range(num_batches_test): - x, cpos, y = create_batch_piped_data(test_data, i*batch_size_eval, (i+1)*batch_size_eval, device=device, pad_id=pad_id) - logits = model(x, cpos, ignore_cpos=ignore_cpos) + x, cpos, attention_masks, y = create_batch_piped_data(test_data, i * batch_size_eval, + (i + 1) * batch_size_eval, device=device, + pad_id=pad_id) + logits = model(x, attention_mask=attention_masks, center_positions=cpos, ignore_cpos=ignore_cpos) # Track loss and logits running_loss_test.append(loss.item()) @@ -221,12 +328,20 @@ def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_pa print_report(epoch, running_loss, all_logits, y=y_train, name='Train') print_report(epoch, running_loss_test, all_logits_test, y=y_test, name='Test') - _report = classification_report(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1), output_dict=True) + _report = classification_report(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1), + output_dict=True) if not winner_report or _report[config.train['metric']['base']][config.train['metric']['score']] > \ winner_report['report'][config.train['metric']['base']][config.train['metric']['score']]: - report = classification_report(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1), output_dict=True) + report = classification_report(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1), + output_dict=True) + cm = confusion_matrix(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1), normalize='true') + report_train = classification_report(y_train, np.argmax(np.concatenate(all_logits, axis=0), axis=1), + output_dict=True) + + winner_report['confusion_matrix'] = cm winner_report['report'] = report + winner_report['report_train'] = report_train winner_report['epoch'] = epoch # Save if needed @@ -237,8 +352,11 @@ def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_pa else: path = os.path.join(save_dir_path, 'model.dat') torch.save(model.state_dict(), path) - logger.info("\n##### Model saved to %s at epoch: %d and %s/%s: %s #####\n", path, epoch, config.train['metric']['base'], - config.train['metric']['score'], winner_report['report'][config.train['metric']['base']][config.train['metric']['score']]) + logger.info("\n##### Model saved to %s at epoch: %d and %s/%s: %s #####\n", path, epoch, + config.train['metric']['base'], + config.train['metric']['score'], + winner_report['report'][config.train['metric']['base']][ + config.train['metric']['score']]) return winner_report @@ -255,7 +373,7 @@ def eval_model(model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: T Returns: Dict: Results (precision, recall, f1, examples, confusion matrix) """ - device = torch.device(config.general['device']) # Create a torch device + device = torch.device(config.general['device']) # Create a torch device batch_size_eval = config.general['batch_size_eval'] pad_id = config.model['padding_idx'] ignore_cpos = config.model['ignore_cpos'] @@ -263,9 +381,9 @@ def eval_model(model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: T if class_weights is not None: class_weights = torch.FloatTensor(class_weights).to(device) - 
criterion = nn.CrossEntropyLoss(weight=class_weights) # Set the criterion to Cross Entropy Loss + criterion = nn.CrossEntropyLoss(weight=class_weights) # Set the criterion to Cross Entropy Loss else: - criterion = nn.CrossEntropyLoss() # Set the criterion to Cross Entropy Loss + criterion = nn.CrossEntropyLoss() # Set the criterion to Cross Entropy Loss y_eval = [x[2] for x in data] num_batches = math.ceil(len(data) / batch_size_eval) @@ -276,8 +394,11 @@ def eval_model(model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: T with torch.no_grad(): for i in range(num_batches): - x, cpos, y = create_batch_piped_data(data, i*batch_size_eval, (i+1)*batch_size_eval, device=device, pad_id=pad_id) - logits = model(x, cpos, ignore_cpos=ignore_cpos) + x, cpos, attention_masks, y = create_batch_piped_data(data, i * batch_size_eval, (i + 1) * batch_size_eval, + device=device, pad_id=pad_id) + + logits = model(x, center_positions=cpos, attention_mask=attention_masks, ignore_cpos=ignore_cpos) + loss = criterion(logits, y) # Track loss and logits @@ -290,24 +411,27 @@ def eval_model(model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: T predictions = np.argmax(np.concatenate(all_logits, axis=0), axis=1) precision, recall, f1, support = precision_recall_fscore_support(y_eval, predictions, average=score_average) - labels = [name for (name, _) in sorted(config.general['category_value2id'].items(), key=lambda x:x[1])] + labels = [name for (name, _) in sorted(config.general['category_value2id'].items(), key=lambda x: x[1])] confusion = pd.DataFrame( - data=confusion_matrix(y_eval, predictions,), + data=confusion_matrix(y_eval, predictions, ), columns=["true " + label for label in labels], index=["predicted " + label for label in labels], ) - examples: Dict = {'FP': {}, 'FN': {}, 'TP': {}} id2category_value = {v: k for k, v in config.general['category_value2id'].items()} for i, p in enumerate(predictions): y = id2category_value[y_eval[i]] p = id2category_value[p] c = data[i][1] + if isinstance(c,list): + c = c[-1] + tkns = data[i][0] assert tokenizer.hf_tokenizers is not None - text = tokenizer.hf_tokenizers.decode(tkns[0:c]) + " <<"+ tokenizer.hf_tokenizers.decode(tkns[c:c+1]).strip() + ">> " + \ - tokenizer.hf_tokenizers.decode(tkns[c+1:]) + text = tokenizer.hf_tokenizers.decode(tkns[0:c]) + " <<" + tokenizer.hf_tokenizers.decode( + tkns[c:c + 1]).strip() + ">> " + \ + tokenizer.hf_tokenizers.decode(tkns[c + 1:]) info = "Predicted: {}, True: {}".format(p, y) if p != y: # We made a mistake diff --git a/medcat/utils/meta_cat/models.py b/medcat/utils/meta_cat/models.py index c28a2c6ed..70e235316 100644 --- a/medcat/utils/meta_cat/models.py +++ b/medcat/utils/meta_cat/models.py @@ -1,11 +1,11 @@ import torch from collections import OrderedDict -from typing import Optional, Any, List +from typing import Optional, Any, List, Iterable from torch import nn, Tensor -from torch.nn import CrossEntropyLoss -from transformers import BertPreTrainedModel, BertModel, BertConfig -from transformers.modeling_outputs import TokenClassifierOutput +from transformers import BertModel, AutoConfig from medcat.meta_cat import ConfigMetaCAT +import logging +logger = logging.getLogger(__name__) class LSTM(nn.Module): @@ -24,7 +24,7 @@ def __init__(self, embeddings: Optional[Tensor], config: ConfigMetaCAT) -> None: # Disable training for the embeddings - IMPORTANT self.embeddings.weight.requires_grad = config.model['emb_grad'] - # Create the RNN cell - devide + # Create the RNN cell - devide self.rnn = 
nn.LSTM(input_size=config.model['input_size'], hidden_size=config.model['hidden_size'] // config.model['num_directions'], num_layers=config.model['num_layers'], @@ -47,10 +47,11 @@ def forward(self, mask = attention_mask # Embed the input: from id -> vec - x = self.embeddings(x) # x.shape = batch_size x sequence_length x emb_size + x = self.embeddings(x) # x.shape = batch_size x sequence_length x emb_size # Tell RNN to ignore padding and set the batch_first to True - x = nn.utils.rnn.pack_padded_sequence(x, mask.sum(1).int().view(-1).cpu(), batch_first=True, enforce_sorted=False) + x = nn.utils.rnn.pack_padded_sequence(x, mask.sum(1).int().view(-1).cpu(), batch_first=True, + enforce_sorted=False) # Run 'x' through the RNN x, hidden = self.rnn(x) @@ -59,16 +60,22 @@ def forward(self, x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True) # Get what we need - row_indices = torch.arange(0, x.size(0)).long() + # row_indices = torch.arange(0, x.size(0)).long() # If this is True we will always take the last state and not CPOS if ignore_cpos: x = hidden[0] x = x.view(self.config.model['num_layers'], self.config.model['num_directions'], -1, - self.config.model['hidden_size']//self.config.model['num_directions']) + self.config.model['hidden_size'] // self.config.model['num_directions']) x = x[-1, :, :, :].permute(1, 2, 0).reshape(-1, self.config.model['hidden_size']) else: - x = x[row_indices, center_positions, :] + x_all = [] + for i, indices in enumerate(center_positions): + this_hidden = x[i, indices, :] + to_append, _ = torch.max(this_hidden, dim=0) + x_all.append(to_append) + + x = torch.stack(x_all) # Push x through the fc network and add dropout x = self.d1(x) @@ -77,34 +84,61 @@ def forward(self, return x -class BertForMetaAnnotation(BertPreTrainedModel): - +class BertForMetaAnnotation(nn.Module): _keys_to_ignore_on_load_unexpected: List[str] = [r"pooler"] # type: ignore - def __init__(self, config: BertConfig) -> None: - super().__init__(config) - self.num_labels = config.num_labels - - self.bert = BertModel(config, add_pooling_layer=False) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) + def __init__(self, config): + super(BertForMetaAnnotation, self).__init__() + _bertconfig = AutoConfig.from_pretrained(config.model.model_variant,num_hidden_layers=config.model['num_layers']) + if config.model['input_size'] != _bertconfig.hidden_size: + logger.warning(f"\nInput size for {config.model.model_variant} model should be {_bertconfig.hidden_size}, provided input size is {config.model['input_size']} Input size changed to {_bertconfig.hidden_size}") - self.init_weights() # type: ignore + bert = BertModel.from_pretrained(config.model.model_variant, config=_bertconfig) + self.config = config + self.config.use_return_dict = False + self.bert = bert + self.num_labels = config.model["nclasses"] + for param in self.bert.parameters(): + param.requires_grad = not config.model.model_freeze_layers + + hidden_size_2 = int(config.model.hidden_size / 2) + # dropout layer + self.dropout = nn.Dropout(config.model.dropout) + # relu activation function + self.relu = nn.ReLU() + # dense layer 1 + self.fc1 = nn.Linear(_bertconfig.hidden_size*2, config.model.hidden_size) + # dense layer 2 + self.fc2 = nn.Linear(config.model.hidden_size, hidden_size_2) + # dense layer 3 + self.fc3 = nn.Linear(hidden_size_2, hidden_size_2) + # dense layer 3 (Output layer) + model_arch_config = config.model.model_architecture_config + if 
model_arch_config is not None: + if model_arch_config['fc2'] is True or model_arch_config['fc3'] is True: + self.fc4 = nn.Linear(hidden_size_2, self.num_labels) + else: + self.fc4 = nn.Linear(config.model.hidden_size, self.num_labels) + else: + self.fc4 = nn.Linear(hidden_size_2, self.num_labels) + # softmax activation function + self.softmax = nn.LogSoftmax(dim=1) def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - center_positions: Optional[Any] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> TokenClassifierOutput: + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + center_positions: Iterable[Any] = [], + ignore_cpos: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ): """labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - 1]``. @@ -119,47 +153,59 @@ def forward( labels (Optional[torch.LongTensor]): Labels. Defaults to None. center_positions (Optional[Any]): Cennter positions. Defaults to None. output_attentions (Optional[bool]): Output attentions. Defaults to None. + ignore_cpos: If center positions are to be ignored. output_hidden_states (Optional[bool]): Output hidden states. Defaults to None. return_dict (Optional[bool]): Whether to return a dict. Defaults to None. Returns: TokenClassifierOutput: The token classifier output. 
""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # type: ignore + # return_dict = return_dict if return_dict is not None else self.config.use_return_dict # type: ignore - outputs = self.bert( # type: ignore + outputs = self.bert( # type: ignore input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + attention_mask=attention_mask, output_hidden_states=True ) - sequence_output = outputs[0] # (batch_size, sequence_length, hidden_size) - - row_indices = torch.arange(0, sequence_output.size(0)).long() - sequence_output = sequence_output[row_indices, center_positions, :] + x_all = [] + for i, indices in enumerate(center_positions): + this_hidden: torch.Tensor = outputs.last_hidden_state[i, indices, :] + to_append, _ = torch.max(this_hidden, dim=0) + x_all.append(to_append) - sequence_output = self.dropout(sequence_output) - logits = self.classifier(sequence_output) + x = torch.stack(x_all) - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - # Only keep active parts of the loss - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + pooled_output = outputs[1] + x = torch.cat((x, pooled_output), dim=1) - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) + # fc1 + x = self.dropout(x) + x = self.fc1(x) + x = self.relu(x) + + if self.config.model.model_architecture_config is not None: + if self.config.model.model_architecture_config['fc2'] is True: + # fc2 + x = self.fc2(x) + x = self.relu(x) + x = self.dropout(x) + + if self.config.model.model_architecture_config['fc3'] is True: + # fc3 + x = self.fc3(x) + x = self.relu(x) + x = self.dropout(x) + else: + # fc2 + x = self.fc2(x) + x = self.relu(x) + x = self.dropout(x) + + # fc3 + x = self.fc3(x) + x = self.relu(x) + x = self.dropout(x) + + # output layer + x = self.fc4(x) + return x diff --git a/medcat/utils/ner/__init__.py b/medcat/utils/ner/__init__.py index 2657c7df7..5d296dc3a 100644 --- a/medcat/utils/ner/__init__.py +++ b/medcat/utils/ner/__init__.py @@ -1,2 +1,2 @@ from .metrics import metrics -from .helpers import deid_text, make_or_update_cdb +from .helpers import make_or_update_cdb diff --git a/medcat/utils/ner/deid.py b/medcat/utils/ner/deid.py index d71b52004..688bb1ea6 100644 --- a/medcat/utils/ner/deid.py +++ b/medcat/utils/ner/deid.py @@ -40,7 +40,7 @@ from medcat.cat import CAT from medcat.utils.ner.model import NerModel -from medcat.utils.ner.helpers import _deid_text as deid_text, replace_entities_in_text +from medcat.utils.ner.helpers import replace_entities_in_text logger = logging.getLogger(__name__) @@ -69,6 +69,12 @@ def train(self, json_path: Union[str, list, None], def deid_text(self, text: str, redact: bool = False) -> str: """Deidentify text and potentially redact information. + De-identified text. + If redaction is enabled, identifiable entities will be + replaced with starts (e.g `*****`). + Otherwise, the replacement will be the CUI or in other words, + the type of information that was hidden (e.g [PATIENT]). + Args: text (str): The text to deidentify. redact (bool): Whether to redact the information. 
@@ -76,8 +82,8 @@ def deid_text(self, text: str, redact: bool = False) -> str: Returns: str: The deidentified text. """ - self.cat.get_entities - return deid_text(self.cat, text, redact=redact) + entities = self.cat.get_entities(text)['entities'] + return replace_entities_in_text(text, entities, self.cat.cdb.get_name, redact=redact) def deid_multi_texts(self, texts: Union[Iterable[str], Iterable[Tuple]], diff --git a/medcat/utils/ner/helpers.py b/medcat/utils/ner/helpers.py index 65ae7050c..bea1e45ca 100644 --- a/medcat/utils/ner/helpers.py +++ b/medcat/utils/ner/helpers.py @@ -3,35 +3,6 @@ from medcat.utils.data_utils import count_annotations from medcat.cdb import CDB -from medcat.utils.decorators import deprecated - - -# For now, we will keep this method separate from the above class -# This is so that we wouldn't need to create a thorwaway object -# when calling the method from .helpers where it used to be. -# After the deprecated method in .helpers is removed, we can -# move this to a proper class method. -def _deid_text(cat, text: str, redact: bool = False) -> str: - """De-identify text. - - De-identified text. - If redaction is enabled, identifiable entities will be - replaced with starts (e.g `*****`). - Otherwise, the replacement will be the CUI or in other words, - the type of information that was hidden (e.g [PATIENT]). - - - Args: - cat (CAT): The CAT object to use for deid. - text (str): The input document. - redact (bool): Whether to redact. Defaults to False. - - Returns: - str: The de-identified document. - """ - entities = cat.get_entities(text)['entities'] - return replace_entities_in_text(text, entities, cat.cdb.get_name, redact=redact) - def replace_entities_in_text(text: str, entities: Dict, @@ -45,13 +16,6 @@ def replace_entities_in_text(text: str, return new_text -@deprecated("API now allows creating a DeId model (medcat.utils.ner.deid.DeIdModel). " - "It aims to simplify the usage of DeId models. " - "The use of this model is encouraged over the use of this method.") -def deid_text(*args, **kwargs) -> str: - return _deid_text(*args, **kwargs) - - def make_or_update_cdb(json_path: str, cdb: Optional[CDB] = None, min_count: int = 0) -> CDB: """Creates a new CDB or updates an existing one with new diff --git a/medcat/utils/saving/coding.py b/medcat/utils/saving/coding.py index 81a8420aa..89f9c0651 100644 --- a/medcat/utils/saving/coding.py +++ b/medcat/utils/saving/coding.py @@ -1,6 +1,7 @@ from typing import Any, Protocol, runtime_checkable, List, Union, Type, Optional, Callable import json +import re @runtime_checkable @@ -35,6 +36,7 @@ def try_encode(self, obj: object) -> Any: SET_IDENTIFIER = '==SET==' +PATTERN_IDENTIFIER = "==PATTERN==" class SetEncoder(PartEncoder): @@ -79,10 +81,34 @@ def try_decode(self, dct: dict) -> Union[dict, set]: return dct +class PatternEncoder(PartEncoder): + + def try_encode(self, obj): + if isinstance(obj, re.Pattern): + return {PATTERN_IDENTIFIER: obj.pattern} + raise UnsuitableObject() + + +class PatternDecoder(PartDecoder): + + def try_decode(self, dct: dict) -> Union[dict, re.Pattern]: + """Decode re.Pattern from input dicts. 
+ + Args: + dct (dict): The input dict + + Returns: + Union[dict, re.Pattern]: The original dict if this was not a serialized pattern, the pattern otherwise + """ + if PATTERN_IDENTIFIER in dct: + return re.compile(dct[PATTERN_IDENTIFIER]) + return dct + + PostProcessor = Callable[[Any], None] # CDB -> None -DEFAULT_ENCODERS: List[Type[PartEncoder]] = [SetEncoder, ] -DEFAULT_DECODERS: List[Type[PartDecoder]] = [SetDecoder, ] +DEFAULT_ENCODERS: List[Type[PartEncoder]] = [SetEncoder, PatternEncoder] +DEFAULT_DECODERS: List[Type[PartDecoder]] = [SetDecoder, PatternDecoder] LOADING_POSTPROCESSORS: List[PostProcessor] = [] @@ -133,6 +159,8 @@ def object_hook(self, dct: dict) -> Any: def def_inst(cls) -> 'CustomDelegatingDecoder': if cls._def_inst is None: cls._def_inst = cls([_cls() for _cls in DEFAULT_DECODERS]) + elif len(cls._def_inst._delegates) < len(DEFAULT_DECODERS): + cls._def_inst = cls([_cls() for _cls in DEFAULT_DECODERS]) return cls._def_inst diff --git a/medcat/utils/saving/envsnapshot.py b/medcat/utils/saving/envsnapshot.py new file mode 100644 index 000000000..262c48410 --- /dev/null +++ b/medcat/utils/saving/envsnapshot.py @@ -0,0 +1,73 @@ +from typing import List, Dict, Any, Set + +import os +import re +import pkg_resources +import platform + + +ENV_SNAPSHOT_FILE_NAME = "environment_snapshot.json" + +INSTALL_REQUIRES_FILE_PATH = os.path.join(os.path.dirname(__file__), + "..", "..", "..", + "install_requires.txt") +# NOTE: The install_requires.txt file is copied into the wheel during build +# so that it can be included in the distributed package. +# However, that means it's 1 folder closer to this file since it'll now +# be in the root of the package rather than the root of the project. +INSTALL_REQUIRES_FILE_PATH_PIP = os.path.join(os.path.dirname(__file__), + "..", "..", + "install_requires.txt") + + +def get_direct_dependencies() -> Set[str]: + """Get the set of direct dependency names. + + The current implementation reads install_requires.txt for dependencies, + removes comments, whitespace, quotes; removes the versions and returns + the names as a set. + + Returns: + Set[str]: The set of direct dependency names. + """ + req_file = INSTALL_REQUIRES_FILE_PATH + if not os.path.exists(req_file): + # When pip-installed. See note above near constant definition + req_file = INSTALL_REQUIRES_FILE_PATH_PIP + with open(req_file) as f: + # read every line, strip quotes and comments + dep_lines = [line.split("#")[0].replace("'", "").replace('"', "").strip() for line in f.readlines()] + # remove comment-only (or empty) lines + deps = [dep for dep in dep_lines if dep] + return set(re.split("[@<=>~]", dep)[0].strip() for dep in deps) + + +def get_installed_packages() -> List[List[str]]: + """Get the installed packages and their versions. + + Returns: + List[List[str]]: List of lists. Each item consists of a dependency name and its version. + """ + direct_deps = get_direct_dependencies() + installed_packages = [] + for package in pkg_resources.working_set: + if package.project_name not in direct_deps: + continue + installed_packages.append([package.project_name, package.version]) + return installed_packages + + +def get_environment_info() -> Dict[str, Any]: + """Get the current environment information. + + This includes dependency versions, the OS, the CPU architecture and the Python version. 
+ + Returns: + Dict[str, Any]: The environment information. + """ + return { + "dependencies": get_installed_packages(), + "os": platform.platform(), + "cpu_architecture": platform.machine(), + "python_version": platform.python_version() + } diff --git a/setup.py b/setup.py index cfb824727..549e7c091 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,18 @@ import setuptools +import shutil with open("./README.md", "r") as fh: long_description = fh.read() +# make a copy of install requirements so that it gets distributed with the wheel +shutil.copy('install_requires.txt', 'medcat/install_requires.txt') + +with open("install_requires.txt") as f: + # read every line, strip quotes and comments + dep_lines = [line.split("#")[0].replace("'", "").replace('"', "").strip() for line in f.readlines()] + # remove comment-only (or empty) lines + install_requires = [dep for dep in dep_lines if dep] + setuptools.setup( name="medcat", @@ -17,31 +27,9 @@ packages=['medcat', 'medcat.utils', 'medcat.preprocessing', 'medcat.ner', 'medcat.linking', 'medcat.datasets', 'medcat.tokenizers', 'medcat.utils.meta_cat', 'medcat.pipeline', 'medcat.utils.ner', 'medcat.utils.relation_extraction', 'medcat.utils.saving', 'medcat.utils.regression', 'medcat.stats'], - install_requires=[ - 'numpy>=1.22.0,<1.26.0', # 1.22.0 is first to support python 3.11; post 1.26.0 there's issues with scipy - 'pandas>=1.4.2', # first to support 3.11 - 'gensim>=4.3.0,<5.0.0', # 5.3.0 is first to support 3.11; avoid major version bump - 'spacy>=3.6.0,<4.0.0', # Some later model packs (e.g HPO) are made with 3.6.0 spacy model; avoid major version bump - 'scipy~=1.9.2', # 1.9.2 is first to support 3.11 - 'transformers>=4.34.0,<5.0.0', # avoid major version bump - 'accelerate>=0.23.0', # required by Trainer class in de-id - 'torch>=1.13.0,<3.0.0', # 1.13 is first to support 3.11; 2.1.2 has been compatible, but avoid major 3.0.0 for now - 'tqdm>=4.27', - 'scikit-learn>=1.1.3,<2.0.0', # 1.1.3 is first to supporrt 3.11; avoid major version bump - 'dill>=0.3.6,<1.0.0', # stuff saved in 0.3.6/0.3.7 is not always compatible with 0.3.4/0.3.5; avoid major bump - 'datasets>=2.2.2,<3.0.0', # avoid major bump - 'jsonpickle>=2.0.0', # allow later versions, tested with 3.0.0 - 'psutil>=5.8.0', - # 0.70.12 uses older version of dill (i.e less than 0.3.5) which is required for datasets - 'multiprocess~=0.70.12', # 0.70.14 seemed to work just fine - 'aiofiles>=0.8.0', # allow later versions, tested with 22.1.0 - 'ipywidgets>=7.6.5', # allow later versions, tested with 0.8.0 - 'xxhash>=3.0.0', # allow later versions, tested with 3.1.0 - 'blis>=0.7.5', # allow later versions, tested with 0.7.9 - 'click>=8.0.4', # allow later versions, tested with 8.1.3 - 'pydantic>=1.10.0,<2.0', # for spacy compatibility; avoid 2.0 due to breaking changes - "humanfriendly~=10.0", # for human readable file / RAM sizes - ], + install_requires=install_requires, + include_package_data=True, + package_data={"medcat": ["install_requires.txt"]}, classifiers=[ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.8", diff --git a/tests/__init__.py b/tests/__init__.py index e69de29bb..7fbc9f3b2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,25 @@ +from typing import Callable, Tuple + +from medcat.utils import decorators + + +class DeprecatedMethodCallException(ValueError): + + def __init__(self, func: Callable, msg: str, + depr_version: Tuple[int, int, int], + removal_version: Tuple[int, int, int]) -> None: + super().__init__(f"A deprecated method {func.__name__} was 
called. Deprecation message:\n{msg}\n" + f"The method was deprecated in v{depr_version} and is scheduled for " + f"removal in v{removal_version}") + + +def deprecation_exception_raiser(message: str, depr_version: Tuple[int, int, int], + removal_version: Tuple[int, int, int]): + def decorator(func: Callable) -> Callable: + def wrapper(*_, **__): + raise DeprecatedMethodCallException(func, message, depr_version, removal_version) + return wrapper + return decorator + + +decorators.deprecated = deprecation_exception_raiser diff --git a/tests/archive_tests/test_cdb_maker_archive.py b/tests/archive_tests/test_cdb_maker_archive.py deleted file mode 100644 index 9e2fc2d72..000000000 --- a/tests/archive_tests/test_cdb_maker_archive.py +++ /dev/null @@ -1,124 +0,0 @@ -import logging -import unittest -import numpy as np -from medcat.cdb import CDB -from medcat.cdb_maker import CDBMaker -from medcat.config import Config -from medcat.preprocessing.cleaners import prepare_name - - -class CdbMakerArchiveTests(unittest.TestCase): - - def setUp(self): - self.config = Config() - self.config.general['log_level'] = logging.DEBUG - self.maker = CDBMaker(self.config) - - # Building a new CDB from two files (full_build) - csvs = ['../examples/cdb.csv', '../examples/cdb_2.csv'] - self.cdb = self.maker.prepare_csvs(csvs, full_build=True) - - def test_prepare_csvs(self): - assert len(self.cdb.cui2names) == 3 - assert len(self.cdb.cui2snames) == 3 - assert len(self.cdb.name2cuis) == 5 - assert len(self.cdb.cui2tags) == 3 - assert len(self.cdb.cui2preferred_name) == 2 - assert len(self.cdb.cui2context_vectors) == 3 - assert len(self.cdb.cui2count_train) == 3 - assert self.cdb.name2cuis2status['virus']['C0000039'] == 'P' - assert self.cdb.cui2type_ids['C0000039'] == {'T234', 'T109', 'T123'} - assert self.cdb.addl_info['cui2original_names']['C0000039'] == {'Virus', 'Virus K', 'Virus M', 'Virus Z'} - assert self.cdb.addl_info['cui2description']['C0000039'].startswith("Synthetic") - - def test_name_addition(self): - self.cdb.add_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', self.maker.pipe.get_spacy_nlp(), {}, self.config), name_status='P', full_build=True) - assert self.cdb.addl_info['cui2original_names']['C0000239'] == {'MY: new,-_! Name.', 'Second csv'} - assert 'my:newname.' in self.cdb.name2cuis - assert 'my:new' in self.cdb.snames - assert 'my:newname.' in self.cdb.name2cuis2status - assert self.cdb.name2cuis2status['my:newname.'] == {'C0000239': 'P'} - - def test_name_removal(self): - self.cdb.remove_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', self.maker.pipe.get_spacy_nlp(), {}, self.config)) - # Run again to make sure it does not break anything - self.cdb.remove_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', self.maker.pipe.get_spacy_nlp(), {}, self.config)) - assert len(self.cdb.name2cuis) == 5 - assert 'my:newname.' 
not in self.cdb.name2cuis2status - - def test_filtering(self): - cuis_to_keep = {'C0000039'} # Because of transition 2 will be kept - self.cdb.filter_by_cui(cuis_to_keep=cuis_to_keep) - assert len(self.cdb.cui2names) == 2 - assert len(self.cdb.name2cuis) == 4 - assert len(self.cdb.snames) == 4 - - def test_vector_addition(self): - self.cdb.reset_training() - np.random.seed(11) - cuis = list(self.cdb.cui2names.keys()) - for i in range(2): - for cui in cuis: - vectors = {} - for cntx_type in self.config.linking['context_vector_sizes']: - vectors[cntx_type] = np.random.rand(300) - self.cdb.update_context_vector(cui, vectors, negative=False) - - assert self.cdb.cui2count_train['C0000139'] == 2 - assert self.cdb.cui2context_vectors['C0000139']['long'].shape[0] == 300 - - - def test_negative(self): - cuis = list(self.cdb.cui2names.keys()) - for cui in cuis: - vectors = {} - for cntx_type in self.config.linking['context_vector_sizes']: - vectors[cntx_type] = np.random.rand(300) - self.cdb.update_context_vector(cui, vectors, negative=True) - - assert self.cdb.cui2count_train['C0000139'] == 2 - assert self.cdb.cui2context_vectors['C0000139']['long'].shape[0] == 300 - - def test_save_and_load(self): - self.cdb.save("./tmp_cdb.dat") - cdb2 = CDB.load('./tmp_cdb.dat') - # Check a random thing - assert cdb2.cui2context_vectors['C0000139']['long'][7] == self.cdb.cui2context_vectors['C0000139']['long'][7] - - def test_training_import(self): - cdb2 = CDB.load('./tmp_cdb.dat') - self.cdb.reset_training() - cdb2.reset_training() - np.random.seed(11) - cuis = list(self.cdb.cui2names.keys()) - for i in range(2): - for cui in cuis: - vectors = {} - for cntx_type in self.config.linking['context_vector_sizes']: - vectors[cntx_type] = np.random.rand(300) - self.cdb.update_context_vector(cui, vectors, negative=False) - - cdb2.import_training(cdb=self.cdb, overwrite=True) - assert cdb2.cui2context_vectors['C0000139']['long'][7] == self.cdb.cui2context_vectors['C0000139']['long'][7] - assert cdb2.cui2count_train['C0000139'] == self.cdb.cui2count_train['C0000139'] - - def test_concept_similarity(self): - cdb = CDB(config=self.config) - np.random.seed(11) - for i in range(500): - cui = "C" + str(i) - type_ids = {'T-' + str(i%10)} - cdb._add_concept(cui=cui, names=prepare_name('Name: ' + str(i), self.maker.pipe.get_spacy_nlp(), {}, self.config), ontologies=set(), - name_status='P', type_ids=type_ids, description='', full_build=True) - - vectors = {} - for cntx_type in self.config.linking['context_vector_sizes']: - vectors[cntx_type] = np.random.rand(300) - cdb.update_context_vector(cui, vectors, negative=False) - res = cdb.most_similar('C200', 'long', type_id_filter=['T-0'], min_cnt=1, topn=10, force_build=True) - assert len(res) == 10 - - def test_training_reset(self): - self.cdb.reset_training() - assert len(self.cdb.cui2context_vectors['C0']) == 0 - assert self.cdb.cui2count_train['C0'] == 0 diff --git a/tests/archive_tests/test_ner_archive.py b/tests/archive_tests/test_ner_archive.py deleted file mode 100644 index d41ccd0c7..000000000 --- a/tests/archive_tests/test_ner_archive.py +++ /dev/null @@ -1,139 +0,0 @@ -import logging -import unittest -import numpy as np -from timeit import default_timer as timer -from medcat.cdb import CDB -from medcat.preprocessing.tokenizers import spacy_split_all -from medcat.ner.vocab_based_ner import NER -from medcat.preprocessing.taggers import tag_skip_and_punct -from medcat.pipe import Pipe -from medcat.utils.normalizers import BasicSpellChecker -from medcat.vocab import Vocab 
-from medcat.preprocessing.cleaners import prepare_name -from medcat.linking.vector_context_model import ContextModel -from medcat.linking.context_based_linker import Linker -from medcat.config import Config - -from ..helper import VocabDownloader - - -class NerArchiveTests(unittest.TestCase): - - def setUp(self) -> None: - self.config = Config() - self.config.general['log_level'] = logging.INFO - cdb = CDB(config=self.config) - - self.nlp = Pipe(tokenizer=spacy_split_all, config=self.config) - self.nlp.add_tagger(tagger=tag_skip_and_punct, - name='skip_and_punct', - additional_fields=['is_punct']) - - # Add a couple of names - cdb.add_names(cui='S-229004', names=prepare_name('Movar', self.nlp, {}, self.config)) - cdb.add_names(cui='S-229004', names=prepare_name('Movar viruses', self.nlp, {}, self.config)) - cdb.add_names(cui='S-229005', names=prepare_name('CDB', self.nlp, {}, self.config)) - # Check - #assert cdb.cui2names == {'S-229004': {'movar', 'movarvirus', 'movarviruses'}, 'S-229005': {'cdb'}} - - downloader = VocabDownloader() - self.vocab_path = downloader.vocab_path - downloader.check_or_download() - - vocab = Vocab.load(self.vocab_path) - # Make the pipeline - self.nlp = Pipe(tokenizer=spacy_split_all, config=self.config) - self.nlp.add_tagger(tagger=tag_skip_and_punct, - name='skip_and_punct', - additional_fields=['is_punct']) - spell_checker = BasicSpellChecker(cdb_vocab=cdb.vocab, config=self.config, data_vocab=vocab) - self.nlp.add_token_normalizer(spell_checker=spell_checker, config=self.config) - ner = NER(cdb, self.config) - self.nlp.add_ner(ner) - - # Add Linker - link = Linker(cdb, vocab, self.config) - self.nlp.add_linker(link) - - self.text = "CDB - I was running and then Movar Virus attacked and CDb" - - def tearDown(self) -> None: - self.nlp.destroy() - - def test_limits_for_tokens_and_uppercase(self): - self.config.ner['max_skip_tokens'] = 1 - self.config.ner['upper_case_limit_len'] = 4 - self.config.linking['disamb_length_limit'] = 2 - - d = self.nlp(self.text) - - assert len(d._.ents) == 2 - assert d._.ents[0]._.link_candidates[0] == 'S-229004' - - def test_change_limit_for_skip(self): - self.config.ner['max_skip_tokens'] = 3 - d = self.nlp(self.text) - assert len(d._.ents) == 3 - - def test_change_limit_for_upper_case(self): - self.config.ner['upper_case_limit_len'] = 3 - d = self.nlp(self.text) - assert len(d._.ents) == 4 - - def test_check_name_length_limit(self): - self.config.ner['min_name_len'] = 4 - d = self.nlp(self.text) - assert len(d._.ents) == 2 - - def test_speed(self): - text = "CDB - I was running and then Movar Virus attacked and CDb" - text = text * 300 - self.config.general['spell_check'] = True - start = timer() - for i in range(50): - d = self.nlp(text) - end = timer() - print("Time: ", end - start) - - def test_without_spell_check(self): - # Now without spell check - self.config.general['spell_check'] = False - start = timer() - for i in range(50): - d = self.nlp(self.text) - end = timer() - print("Time: ", end - start) - - - def test_for_linker(self): - self.config = Config() - self.config.general['log_level'] = logging.DEBUG - cdb = CDB(config=self.config) - - # Add a couple of names - cdb.add_names(cui='S-229004', names=prepare_name('Movar', self.nlp, {}, self.config)) - cdb.add_names(cui='S-229004', names=prepare_name('Movar viruses', self.nlp, {}, self.config)) - cdb.add_names(cui='S-229005', names=prepare_name('CDB', self.nlp, {}, self.config)) - cdb.add_names(cui='S-2290045', names=prepare_name('Movar', self.nlp, {}, self.config)) - # 
Check - #assert cdb.cui2names == {'S-229004': {'movar', 'movarvirus', 'movarviruses'}, 'S-229005': {'cdb'}, 'S-2290045': {'movar'}} - - cuis = list(cdb.cui2names.keys()) - for cui in cuis[0:50]: - vectors = {'short': np.random.rand(300), - 'long': np.random.rand(300), - 'medium': np.random.rand(300) - } - cdb.update_context_vector(cui, vectors, negative=False) - - d = self.nlp(self.text) - vocab = Vocab.load(self.vocab_path) - cm = ContextModel(cdb, vocab, self.config) - cm.train_using_negative_sampling('S-229004') - self.config.linking['train_count_threshold'] = 0 - - cm.train('S-229004', d._.ents[1], d) - - cm.similarity('S-229004', d._.ents[1], d) - - cm.disambiguate(['S-2290045', 'S-229004'], d._.ents[1], 'movar', d) diff --git a/tests/check_deprecations.py b/tests/check_deprecations.py new file mode 100644 index 000000000..4d10fba97 --- /dev/null +++ b/tests/check_deprecations.py @@ -0,0 +1,178 @@ +from typing import List, Dict, Optional, Tuple, Callable +import ast +import os +from sys import argv as sys_argv +from sys import exit as sys_exit +from medcat.utils.decorators import deprecated + + +def get_decorator_args(decorator: ast.expr, decorator_name: str) -> Tuple[Optional[List[str]], Optional[Dict[str, str]]]: + if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Name) and decorator.func.id == decorator_name: + return decorator.args, {kw.arg: kw.value for kw in decorator.keywords} + return None, None + + +def is_decorated_with(node: ast.FunctionDef, decorator_name: str) -> Tuple[bool, List[str], Dict[str, str]]: + for decorator in node.decorator_list: + if isinstance(decorator, ast.Name) and decorator.id == decorator_name: + return True, [], {} + elif isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Name) and decorator.func.id == decorator_name: + args, kwargs = get_decorator_args(decorator, decorator_name) + return True, args, kwargs + return False, [], {} + + +class FunctionVisitor(ast.NodeVisitor): + def __init__(self, decorator_name: str): + self.decorator_name = decorator_name + self.decorated_functions: List[Dict[str, Optional[List[str]]]] = [] + self.context: List[str] = [] + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + self.context.append(node.name) + is_decorated, args, kwargs = is_decorated_with(node, self.decorator_name) + if is_decorated: + self.decorated_functions.append({ + 'name': '.'.join(self.context), + 'args': args, + 'kwargs': kwargs + }) + self.generic_visit(node) + self.context.pop() + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: + self.visit_FunctionDef(node) + + def visit_ClassDef(self, node: ast.ClassDef) -> None: + self.context.append(node.name) + self.generic_visit(node) + self.context.pop() + + +def find_decorated_functions_in_file(filepath: str, decorator_name: str) -> List[Dict[str, Optional[List[str]]]]: + with open(filepath, "r") as source: + tree = ast.parse(source.read()) + + visitor = FunctionVisitor(decorator_name) + visitor.visit(tree) + return visitor.decorated_functions + + +def find_decorated_functions_in_codebase(codebase_path: str, decorator_name: str) -> Dict[str, List[Dict[str, Optional[List[str]]]]]: + decorated_functions: Dict[str, List[Dict[str, Optional[List[str]]]]] = {} + for root, _, files in os.walk(codebase_path): + for file in files: + if file.endswith(".py"): + filepath = os.path.join(root, file) + decorated_funcs = find_decorated_functions_in_file(filepath, decorator_name) + if decorated_funcs: + decorated_functions[filepath] = decorated_funcs 
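# Illustrative shape of the mapping built above (the file and function names
# are hypothetical, and the kwargs values are the raw ast nodes collected earlier):
#   {"medcat/some_module.py": [
#       {"name": "SomeClass.old_method", "args": [], "kwargs": {"removal_version": <ast.Tuple>}}]}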
+ return decorated_functions + + +def extract_version_from_tuple(tuple_node: ast.Tuple) -> Tuple[int, int, int]: + """Extract constant values from an ast.Tuple node. + + Args: + tuple_node (ast.Tuple): The AST node representing the tuple. + + Raises: + ValueError: If the tuple contains unsuitable values. + + Returns: + Tuple[int, int, int]: The major, minor, and patch version. + """ + values = [] + for element in tuple_node.elts: + if isinstance(element, ast.Constant): + cur_value = element.value + else: + raise ValueError(f"Unsupported element type in tuple: {type(element)}") + values.append(cur_value) + if not isinstance(cur_value, int): + raise ValueError(f"Unknown type of value in version tuple: {type(cur_value)}: {cur_value}") + if len(values) != 3: + raise ValueError(f"Unexpected number of version elements ({len(values)}): {values}") + return tuple(values) + + +def get_deprecated_methods_that_should_have_been_removed(codebase_path: str, + decorator_name: str, + medcat_version: Tuple[int, int, int] + ) -> List[Tuple[str, str, Tuple[int, int, int]]]: + """Get deprecated methods that should have been removed. + + Args: + codebase_path (str): Path to codebase. + decorator_name (str): Name of decorator to check. + medcat_version (Tuple[int, int, int]): The current MedCAT version. + + Returns: + List[Tuple[str, str, Tuple[int, int, int]]]: + The list of file, method name, and the version in which the method should have been removed. + """ + decorated_functions = find_decorated_functions_in_codebase(codebase_path, decorator_name) + + should_be_removed = [] + for filepath, funcs in decorated_functions.items(): + for func in funcs: + func_name = func['name'] + args, kwargs = func['args'], func['kwargs'] + if 'removal_version' in kwargs: + rem_ver = kwargs['removal_version'] + else: + rem_ver = args[-1] + rem_ver = extract_version_from_tuple(rem_ver) + if rem_ver <= medcat_version: + should_be_removed.append((filepath, func_name, rem_ver)) + return should_be_removed + + +def _ver2str(ver: Tuple[int, int, int]) -> str: + maj, min, patch = ver + return f"v{maj}.{min}.{patch}" + + +def main(args: List[str] = sys_argv[1:], + deprecated_decorator: Callable[[], Callable] = deprecated): + decorator_name = deprecated_decorator.__name__ + pos_args = [arg for arg in args if not arg.startswith("-")] + codebase_path = 'medcat' if len(pos_args) <= 1 else pos_args[1] + print("arg0", repr(args[0])) + remove_ver_prefix = '--remove-prefix' in args + pure_ver = pos_args[0] + if remove_ver_prefix: + # remove v from (e.g.) v1.12.0 + pure_ver = pure_ver[1:] + medcat_version = tuple(int(s) for s in pure_ver.split(".")) + compare_next_minor_release = '--next-version' in args + + # pad out medcat versions + # NOTE: Mostly so that e.g (1, 12, 0) <= (1, 12, 0) would be True. + # Otherwise (1, 12, 0) <= (1, 12) would equate to False. + if len(medcat_version) < 3: + medcat_version = tuple(list(medcat_version) + [0,] * (3 - len(medcat_version))) + # NOTE: In main GHA workflow we know the current minor release + # but after that release has been done, we (generally, but not always!) + # want to start removing deprecated methods due to be removed before + the next minor release. 
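# For example (versions assumed for illustration): "1.12" parses to (1, 12) and
# is padded to (1, 12, 0); with --next-version that becomes (1, 13, 0), so any
# method whose removal_version is (1, 13, 0) or earlier is flagged below.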
+ if compare_next_minor_release: + l_ver = list(medcat_version) + l_ver[1] += 1 + medcat_version = tuple(l_ver) + + to_remove = get_deprecated_methods_that_should_have_been_removed(codebase_path, decorator_name, medcat_version) + + ver_descr = "next" if compare_next_minor_release else "current" + for filepath, func_name, rem_ver in to_remove: + print("SHOULD ALREADY BE REMOVED") + print(f"In file: {filepath}") + print(f" Method: {func_name}") + print(f" Scheduled for removal in: {_ver2str(rem_ver)} ({ver_descr} version: {_ver2str(medcat_version)})") + if to_remove: + print("Found issues - see above") + sys_exit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/medmentions/make_cdb.py b/tests/medmentions/make_cdb.py deleted file mode 100644 index feb8629d2..000000000 --- a/tests/medmentions/make_cdb.py +++ /dev/null @@ -1,120 +0,0 @@ -from medcat.cdb_maker import CDBMaker -from medcat.config import Config, weighted_average -from functools import partial -import numpy as np -import logging - -from ..helper import VocabDownloader - - -config = Config() -config.general['log_level'] = logging.INFO -config.general['spacy_model'] = 'en_core_sci_lg' -maker = CDBMaker(config) - -# Building a new CDB from two files (full_build) -csvs = ['./tmp_medmentions.csv'] -cdb = maker.prepare_csvs(csvs, full_build=True) - -cdb.save("./tmp_cdb.dat") - - -from medcat.vocab import Vocab -from medcat.cdb import CDB -from medcat.cat import CAT - -downloader = VocabDownloader() -vocab_path = downloader.vocab_path -downloader.check_or_download() - -config = Config() -cdb = CDB.load("./tmp_cdb.dat", config=config) -vocab = Vocab.load(vocab_path) - -cdb.reset_training() - -cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab) -cat.config.ner['min_name_len'] = 3 -cat.config.ner['upper_case_limit_len'] = 3 -cat.config.linking['disamb_length_limit'] = 3 -cat.config.linking['filters'] = {'cuis': set()} -cat.config.linking['train_count_threshold'] = -1 -cat.config.linking['context_vector_sizes'] = {'xlong': 27, 'long': 18, 'medium': 9, 'short': 3} -cat.config.linking['context_vector_weights'] = {'xlong': 0, 'long': 0.4, 'medium': 0.4, 'short': 0.2} -cat.config.linking['weighted_average_function'] = partial(weighted_average, factor=0.0004) -cat.config.linking['similarity_threshold_type'] = 'dynamic' -cat.config.linking['similarity_threshold'] = 0.35 -cat.config.linking['calculate_dynamic_threshold'] = True - -cat.train(df.text.values, fine_tune=True) - - -cdb.config.general['spacy_disabled_components'] = ['ner', 'parser', 'vectors', 'textcat', - 'entity_linker', 'sentencizer', 'entity_ruler', 'merge_noun_chunks', - 'merge_entities', 'merge_subtokens'] - -%load_ext autoreload -%autoreload 2 - -# Train -_ = cat.train(open("./tmp_medmentions_text_only.txt", 'r'), fine_tune=False) - -_ = cat.train_supervised("/home/ubuntu/data/medmentions/medmentions.json", reset_cui_count=True, nepochs=13, train_from_false_positives=True, print_stats=3, test_size=0.1) -cdb.save("/home/ubuntu/data/umls/2020ab/cdb_trained_medmen.dat") - - -_ = cat.train_supervised("/home/ubuntu/data/medmentions/medmentions.json", reset_cui_count=False, nepochs=13, train_from_false_positives=True, print_stats=3, test_size=0) - -cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab) -cat.config.linking['similarity_threshold'] = 0.1 -cat.config.ner['min_name_len'] = 2 -cat.config.ner['upper_case_limit_len'] = 1 -cat.config.linking['train_count_threshold'] = -2 -cat.config.linking['filters']['cuis'] = set() -cat.config.linking['context_vector_sizes'] = 
{'xlong': 27, 'long': 18, 'medium': 9, 'short': 3} -cat.config.linking['context_vector_weights'] = {'xlong': 0.1, 'long': 0.4, 'medium': 0.4, 'short': 0.1} -cat.config.linking['similarity_threshold_type'] = 'static' - -cat.config.linking['similarity_threshold_type'] = 'dynamic' -cat.config.linking['similarity_threshold'] = 0.35 -cat.config.linking['calculate_dynamic_threshold'] = True - - -# Print some stats -_ = cat._print_stats(data) - -#Epoch: 0, Prec: 0.4331506351144245, Rec: 0.5207520064957372, F1: 0.47292889758643175 -#p: 0.421 r: 0.507 f1: 0.460 - - -# Remove all names that are numbers -for name in list(cdb.name2cuis.keys()): - if name.replace(".", '').replace("~", '').replace(",", '').replace(":", '').replace("-", '').isnumeric(): - del cdb.name2cuis[name] - print(name) - - -for name in list(cdb.name2cuis.keys()): - if len(name) < 7 and (not name.isalpha()) and len(re.sub("[^A-Za-z]*", '', name)) < 2: - del cdb.name2cuis[name] - print(name) - - - - -# RUN SUPER -cdb = CDB.load("./tmp_cdb.dat") -vocab = Vocab.load(vocab_path) -cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab) - - -# Train supervised -cdb.reset_cui_count() -cat.config.ner['uppe_case_limit_len'] = 1 -cat.config.ner['min_name_len'] = 1 -data_path = "./tmp_medmentions.json" -_ = cat.train_supervised(data_path, use_cui_doc_limit=True, nepochs=30, devalue_others=True, test_size=0.2) - - -cdb = maker.prepare_csvs(csv_paths=csvs) -cdb.save("/home/ubuntu/data/umls/2020ab/cdb_vbg.dat") diff --git a/tests/medmentions/prepare_data.py b/tests/medmentions/prepare_data.py deleted file mode 100644 index 6e1bfdf2e..000000000 --- a/tests/medmentions/prepare_data.py +++ /dev/null @@ -1,7 +0,0 @@ -from medcat.utils.medmentions import original2concept_csv -from medcat.utils.medmentions import original2json -from medcat.utils.medmentions import original2pure_text - -_ = original2json("../../examples/medmentions/medmentions.txt", '../../examples/medmentions/tmp_medmentions.json') -_ = original2concept_csv("../../examples/medmentions/medmentions.txt", '../../examples/medmentions/tmp_medmentions.csv') -original2pure_text("../../examples/medmentions/medmentions.txt", '../../examples/medmentions/tmp_medmentions_text_only.txt') diff --git a/tests/resources/jsonpickle_config.json b/tests/resources/jsonpickle_config.json new file mode 100644 index 000000000..784f933ce --- /dev/null +++ b/tests/resources/jsonpickle_config.json @@ -0,0 +1,274 @@ +{ + "version": { + "py/object": "medcat.config.VersionInfo", + "py/state": { + "__dict__": { + "history": ["0c0de303b6dc0020"], + "meta_cats": {}, + "cdb_info": {}, + "performance": { + "ner": {}, + "meta": {} + }, + "description": "No description", + "id": null, + "last_modified": null, + "location": null, + "ontology": null, + "medcat_version": null + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "cdb_maker": { + "py/object": "medcat.config.CDBMaker", + "py/state": { + "__dict__": { + "name_versions": [ + "LOWER", + "CLEAN" + ], + "multi_separator": "|", + "remove_parenthesis": 5, + "min_letters_required": 2 + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "annotation_output": { + "py/object": "medcat.config.AnnotationOutput", + "py/state": { + "__dict__": { + "doc_extended_info": false, + "context_left": -1, + "context_right": -1, + "lowercase_context": true, + "include_text_in_output": false + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "general": { + 
"py/object": "medcat.config.General", + "py/state": { + "__dict__": { + "spacy_disabled_components": [ + "ner", + "parser", + "vectors", + "textcat", + "entity_linker", + "sentencizer", + "entity_ruler", + "merge_noun_chunks", + "merge_entities", + "merge_subtokens" + ], + "checkpoint": { + "py/object": "medcat.config.CheckPoint", + "py/state": { + "__dict__": { + "output_dir": "checkpoints", + "steps": null, + "max_to_keep": 1 + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "log_level": 20, + "log_format": "%(levelname)s:%(name)s: %(message)s", + "log_path": "./medcat.log", + "spacy_model": "en_core_web_lg", + "separator": "~", + "spell_check": true, + "diacritics": false, + "spell_check_deep": false, + "spell_check_len_limit": 7, + "show_nested_entities": false, + "full_unlink": false, + "workers": 7, + "make_pretty_labels": null, + "map_cui_to_group": false + }, + "__fields_set__": { + "py/set": [ + "spacy_model" + ] + }, + "__private_attribute_values__": {} + } + }, + "preprocessing": { + "py/object": "medcat.config.Preprocessing", + "py/state": { + "__dict__": { + "words_to_skip": { + "py/set": [ + "nos" + ] + }, + "keep_punct": { + "py/set": [ + ".", + ":" + ] + }, + "do_not_normalize": { + "py/set": [ + "VBD", + "VBP", + "VBN", + "JJR", + "JJS", + "VBG" + ] + }, + "skip_stopwords": false, + "min_len_normalize": 5, + "stopwords": { + "py/set": [ + "three", + "two", + "one" + ] + }, + "max_document_length": 1000000 + }, + "__fields_set__": { + "py/set": [ + "stopwords" + ] + }, + "__private_attribute_values__": {} + } + }, + "ner": { + "py/object": "medcat.config.Ner", + "py/state": { + "__dict__": { + "min_name_len": 3, + "max_skip_tokens": 2, + "check_upper_case_names": false, + "upper_case_limit_len": 4, + "try_reverse_word_order": false + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "linking": { + "py/object": "medcat.config.Linking", + "py/state": { + "__dict__": { + "optim": { + "type": "linear", + "base_lr": 1, + "min_lr": 0.00005 + }, + "context_vector_sizes": { + "xlong": 27, + "long": 18, + "medium": 9, + "short": 3 + }, + "context_vector_weights": { + "xlong": 0.1, + "long": 0.4, + "medium": 0.4, + "short": 0.1 + }, + "filters": { + "py/object": "medcat.config.LinkingFilters", + "py/state": { + "__dict__": { + "cuis": { + "py/set": [] + }, + "cuis_exclude": { + "py/set": [] + } + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "train": true, + "random_replacement_unsupervised": 0.8, + "disamb_length_limit": 3, + "filter_before_disamb": false, + "train_count_threshold": 1, + "always_calculate_similarity": false, + "weighted_average_function": { + "py/object": "medcat.config._DefPartial", + "fun": { + "py/reduce": [ + { + "py/type": "functools.partial" + }, + { + "py/tuple": [ + { + "py/function": "medcat.utils.config_utils.weighted_average" + } + ] + }, + { + "py/tuple": [ + { + "py/function": "medcat.utils.config_utils.weighted_average" + }, + { + "py/tuple": [] + }, + { + "factor": 0.0004 + }, + {} + ] + } + ] + } + }, + "calculate_dynamic_threshold": false, + "similarity_threshold_type": "static", + "similarity_threshold": 0.25, + "negative_probability": 0.5, + "negative_ignore_punct_and_num": true, + "prefer_primary_name": 0.35, + "prefer_frequent_concepts": 0.35, + "subsample_after": 30000, + "devalue_linked_concepts": false, + "context_ignore_center_tokens": false + }, + "__fields_set__": { + "py/set": [] + }, + 
"__private_attribute_values__": {} + } + }, + "word_skipper": { + "py/object": "re.Pattern", + "pattern": "^(nos)$" + }, + "punct_checker": { + "py/object": "re.Pattern", + "pattern": "[^a-z0-9]+" + }, + "hash": null + } \ No newline at end of file diff --git a/tests/resources/jsonpickle_meta_cat_config.json b/tests/resources/jsonpickle_meta_cat_config.json new file mode 100644 index 000000000..4da001c6c --- /dev/null +++ b/tests/resources/jsonpickle_meta_cat_config.json @@ -0,0 +1,89 @@ +{ + "general": { + "py/object": "medcat.config_meta_cat.General", + "py/state": { + "__dict__": { + "device": "cpu", + "disable_component_lock": false, + "seed": -100, + "description": "No description", + "category_name": null, + "category_value2id": {}, + "vocab_size": null, + "lowercase": true, + "cntx_left": 15, + "cntx_right": 10, + "replace_center": null, + "batch_size_eval": 5000, + "annotate_overlapping": false, + "tokenizer_name": "bbpe", + "save_and_reuse_tokens": false, + "pipe_batch_size_in_chars": 20000000, + "span_group": null + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "model": { + "py/object": "medcat.config_meta_cat.Model", + "py/state": { + "__dict__": { + "model_name": "lstm", + "model_variant": "bert-base-uncased", + "model_freeze_layers": true, + "num_layers": 2, + "input_size": 300, + "hidden_size": 300, + "dropout": 0.5, + "phase_number": 0, + "category_undersample": "", + "model_architecture_config": { + "fc2": true, + "fc3": false, + "lr_scheduler": true + }, + "num_directions": 2, + "nclasses": 2, + "padding_idx": -1, + "emb_grad": true, + "ignore_cpos": false + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "train": { + "py/object": "medcat.config_meta_cat.Train", + "py/state": { + "__dict__": { + "batch_size": 100, + "nepochs": 50, + "lr": 0.001, + "test_size": 0.1, + "shuffle_data": true, + "class_weights": null, + "compute_class_weights": false, + "score_average": "weighted", + "prerequisites": {}, + "cui_filter": null, + "auto_save_model": true, + "last_train_on": null, + "metric": { + "base": "weighted avg", + "score": "f1-score" + }, + "loss_funct": "cross_entropy", + "gamma": 2 + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + } + } \ No newline at end of file diff --git a/tests/resources/jsonpickle_rel_cat_config.json b/tests/resources/jsonpickle_rel_cat_config.json new file mode 100644 index 000000000..411caaa52 --- /dev/null +++ b/tests/resources/jsonpickle_rel_cat_config.json @@ -0,0 +1,91 @@ +{ + "general": { + "py/object": "medcat.config_rel_cat.General", + "py/state": { + "__dict__": { + "device": "cpu", + "relation_type_filter_pairs": [], + "vocab_size": null, + "lowercase": true, + "cntx_left": 15, + "cntx_right": 15, + "window_size": 300, + "mct_export_max_non_rel_sample_size": 200, + "mct_export_create_addl_rels": false, + "tokenizer_name": "bert", + "model_name": "bert-base-uncased", + "log_level": 20, + "max_seq_length": 512, + "tokenizer_special_tokens": false, + "annotation_schema_tag_ids": [], + "labels2idx": {}, + "idx2labels": {}, + "pin_memory": true, + "seed": 13, + "task": "train" + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "model": { + "py/object": "medcat.config_rel_cat.Model", + "py/state": { + "__dict__": { + "input_size": 300, + "hidden_size": 768, + "hidden_layers": 3, + "model_size": 5120, + "dropout": 0.2, + "num_directions": 2, + "padding_idx": -1, + 
"emb_grad": true, + "ignore_cpos": false + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "train": { + "py/object": "medcat.config_rel_cat.Train", + "py/state": { + "__dict__": { + "nclasses": 2, + "batch_size": 25, + "nepochs": 1, + "lr": 100000, + "adam_epsilon": 0.0001, + "test_size": 0.2, + "gradient_acc_steps": 1, + "multistep_milestones": [ + 2, + 4, + 6, + 8, + 12, + 15, + 18, + 20, + 22, + 24, + 26, + 30 + ], + "multistep_lr_gamma": 0.8, + "max_grad_norm": 1, + "shuffle_data": true, + "class_weights": null, + "score_average": "weighted", + "auto_save_model": true + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + } + } \ No newline at end of file diff --git a/tests/resources/jsonpickle_tner_config.json b/tests/resources/jsonpickle_tner_config.json new file mode 100644 index 000000000..eb3639453 --- /dev/null +++ b/tests/resources/jsonpickle_tner_config.json @@ -0,0 +1,23 @@ +{ + "general": { + "py/object": "medcat.config_transformers_ner.General", + "py/state": { + "__dict__": { + "name": "deid", + "model_name": "roberta-base", + "seed": 13, + "description": "No description", + "pipe_batch_size_in_chars": -100, + "ner_aggregation_strategy": "simple", + "chunking_overlap_window": 5, + "test_size": 0.2, + "last_train_on": null, + "verbose_metrics": false + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + } + } \ No newline at end of file diff --git a/tests/resources/medcat_trainer_export_FAKE_CONCEPTS.json b/tests/resources/medcat_trainer_export_FAKE_CONCEPTS.json new file mode 100644 index 000000000..79f1a0ac4 --- /dev/null +++ b/tests/resources/medcat_trainer_export_FAKE_CONCEPTS.json @@ -0,0 +1,84 @@ +{ + "projects": [ + { + "cuis": "", + "tuis": "", + "name": "TEST-PROJ", + "id": "PROJ_FAKE", + "documents": [ + { + "name": "fake_doc_0", + "id": 100, + "last_modified": "-1", + "text": "This virus is called virus M and was read from the second CSV we could find.", + "annotations": [ + { + "cui": "C0000039", + "start": 5, + "end": 10, + "value": "virus" + }, + { + "cui": "C0000139", + "start": 21, + "end": 28, + "value": "virus M" + }, + { + "cui": "C0000239", + "start": 51, + "end": 62, + "value": "second CSV" + } + ] + }, + { + "name": "fake_doc_1", + "id": 101, + "last_modified": "-1", + "text": "We found a virus. Turned out it was virus M. 
This was the second CSV we looked at.", + "annotations": [ + { + "cui": "C0000039", + "start": 11, + "end": 16, + "value": "virus" + }, + { + "cui": "C0000139", + "start": 36, + "end": 43, + "value": "virus M" + }, + { + "cui": "C0000239", + "start": 58, + "end": 69, + "value": "second CSV" + } + ] + }, + { + "name": "fake_doc_2", + "id": 102, + "last_modified": "-1", + "text": "We opened second CSV and found virus M to be the culprit.", + "annotations": [ + { + "cui": "C0000239", + "start": 10, + "end": 21, + "value": "second CSV" + }, + { + "cui": "C0000139", + "start": 31, + "end": 38, + "value": "virus M" + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/webapp/webapp/demo/__init__.py b/tests/stats/__init__.py similarity index 100% rename from webapp/webapp/demo/__init__.py rename to tests/stats/__init__.py diff --git a/tests/stats/helpers.py b/tests/stats/helpers.py new file mode 100644 index 000000000..80771b11c --- /dev/null +++ b/tests/stats/helpers.py @@ -0,0 +1,17 @@ +from pydantic import create_model_from_typeddict + +from medcat.stats.mctexport import MedCATTrainerExport + + +MCTExportPydanticModel = create_model_from_typeddict(MedCATTrainerExport) + + +def nullify_doc_names_proj_ids(export: MedCATTrainerExport) -> MedCATTrainerExport: + return {'projects': [ + { + 'name': project['name'], + 'documents': sorted([ + {k: v if k != 'name' else '' for k, v in doc.items()} for doc in project['documents'] + ], key=lambda doc: doc['id']) + } for project in export['projects'] + ]} diff --git a/tests/stats/test_kfold.py b/tests/stats/test_kfold.py new file mode 100644 index 000000000..87dcdd454 --- /dev/null +++ b/tests/stats/test_kfold.py @@ -0,0 +1,298 @@ +import os +import json +from typing import Dict, Union, Optional +from copy import deepcopy + +from medcat.stats import kfold +from medcat.cat import CAT +from pydantic.error_wrappers import ValidationError as PydanticValidationError + +import unittest + +from .helpers import MCTExportPydanticModel, nullify_doc_names_proj_ids + + +class MCTExportTests(unittest.TestCase): + EXPORT_PATH = os.path.join(os.path.dirname(__file__), "..", + "resources", "medcat_trainer_export.json") + + @classmethod + def setUpClass(cls) -> None: + with open(cls.EXPORT_PATH) as f: + cls.mct_export = json.load(f) + + def assertIsMCTExport(self, obj): + try: + model = MCTExportPydanticModel(**obj) + except PydanticValidationError as e: + raise AssertionError("Not an MCT export") from e + self.assertIsInstance(model, MCTExportPydanticModel) + + +class KFoldCreatorTests(MCTExportTests): + K = 3 + SPLIT_TYPE = kfold.SplitType.DOCUMENTS + + + def setUp(self) -> None: + self.creator = kfold.get_fold_creator(self.mct_export, self.K, split_type=self.SPLIT_TYPE) + self.folds = self.creator.create_folds() + + def test_folding_does_not_modify_initial_export(self): + with open(self.EXPORT_PATH) as f: + export_copy = json.load(f) + self.assertEqual(export_copy, self.mct_export) + + def test_mct_export_has_correct_format(self): + self.assertIsMCTExport(self.mct_export) + + def test_folds_have_docs(self): + for nr, fold in enumerate(self.folds): + with self.subTest(f"Fold-{nr}"): + self.assertGreater(kfold.count_all_docs(fold), 0) + + def test_folds_have_anns(self): + for nr, fold in enumerate(self.folds): + with self.subTest(f"Fold-{nr}"): + self.assertGreater(kfold.count_all_annotations(fold), 0) + + def test_folds_are_mct_exports(self): + for nr, fold in enumerate(self.folds): + with self.subTest(f"Fold-{nr}"): + self.assertIsMCTExport(fold) + + def 
test_gets_correct_number_of_folds(self): + self.assertEqual(len(self.folds), self.K) + + def test_folds_keep_all_docs(self): + total_docs = 0 + for fold in self.folds: + docs = kfold.count_all_docs(fold) + total_docs += docs + count_all_once = kfold.count_all_docs(self.mct_export) + if self.SPLIT_TYPE is kfold.SplitType.ANNOTATIONS: + # NOTE: This may be greater if split in the middle of a document + # because that document may then exist in both folds + self.assertGreaterEqual(total_docs, count_all_once) + else: + self.assertEqual(total_docs, count_all_once) + + def test_folds_keep_all_anns(self): + total_anns = 0 + for fold in self.folds: + anns = kfold.count_all_annotations(fold) + total_anns += anns + count_all_once = kfold.count_all_annotations(self.mct_export) + self.assertEqual(total_anns, count_all_once) + + def test_1fold_same_as_orig(self): + folds = kfold.get_fold_creator(self.mct_export, 1, split_type=self.SPLIT_TYPE).create_folds() + self.assertEqual(len(folds), 1) + fold, = folds + self.assertIsInstance(fold, dict) + self.assertIsMCTExport(fold) + self.assertEqual( + nullify_doc_names_proj_ids(self.mct_export), + nullify_doc_names_proj_ids(fold), + ) + + def test_has_reasonable_annotations_per_folds(self): + anns_per_folds = [kfold.count_all_annotations(fold) for fold in self.folds] + print(f"ANNS per folds:\n{anns_per_folds}") + docs_per_folds = [kfold.count_all_docs(fold) for fold in self.folds] + print(f"DOCS per folds:\n{docs_per_folds}") + + +# this is a tailor-made export that +# just contains a few "documents" +# with the fake CUIs "annotated" +NEW_EXPORT_PATH = os.path.join(os.path.dirname(__file__), "..", + "resources", "medcat_trainer_export_FAKE_CONCEPTS.json") + + +class KFoldCreatorPerAnnsTests(KFoldCreatorTests): + SPLIT_TYPE = kfold.SplitType.ANNOTATIONS + + +class KFoldCreatorPerWeightedDocsTests(KFoldCreatorTests): + SPLIT_TYPE = kfold.SplitType.DOCUMENTS_WEIGHTED + # should have a total of 435, so 145 per fold in an ideal world + # but we'll allow the following deviation + PERMITTED_MAX_DEVIATION_IN_ANNS = 5 + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.total_anns = kfold.count_all_annotations(cls.mct_export) + cls.expected_anns_per_fold = cls.total_anns // cls.K + cls.expected_lower_bound = cls.expected_anns_per_fold - cls.PERMITTED_MAX_DEVIATION_IN_ANNS + cls.expected_upper_bound = cls.expected_anns_per_fold + cls.PERMITTED_MAX_DEVIATION_IN_ANNS + + def test_has_reasonable_annotations_per_folds(self): + anns_per_folds = [kfold.count_all_annotations(fold) for fold in self.folds] + for nr, anns in enumerate(anns_per_folds): + with self.subTest(f"Fold-{nr}"): + self.assertGreater(anns, self.expected_lower_bound) + self.assertLess(anns, self.expected_upper_bound) + # NOTE: as of testing, this will split [146, 145, 144] + # whereas regular per-docs split will have [140, 163, 132] + + +class KFoldCreatorNewExportTests(KFoldCreatorTests): + EXPORT_PATH = NEW_EXPORT_PATH + + +class KFoldCreatorNewExportAnnsTests(KFoldCreatorNewExportTests): + SPLIT_TYPE = kfold.SplitType.ANNOTATIONS + + +class KFoldCreatorNewExportWeightedDocsTests(KFoldCreatorNewExportTests): + SPLIT_TYPE = kfold.SplitType.DOCUMENTS_WEIGHTED + + +class KFoldCATTests(MCTExportTests): + _names = ['fps', 'fns', 'tps', 'prec', 'rec', 'f1', 'counts', 'examples'] + EXPORT_PATH = NEW_EXPORT_PATH + CAT_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "examples") + TOLERANCE_PLACES = 10 # tolerance of 10 digits + + @classmethod + def setUpClass(cls) -> None:
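+        # NOTE: the model pack is loaded once for the whole test class since loading it is comparatively slow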
super().setUpClass() + cls.cat = CAT.load_model_pack(cls.CAT_PATH) + + def setUp(self) -> None: + super().setUp() + self.reg_stats = self.cat._print_stats(self.mct_export, do_print=False) + # TODO - remove + self.maxDiff = 4000 + + # NOTE: Due to floating point errors, sometimes we may get slightly different results + def assertDictsAlmostEqual(self, d1: Dict[str, Union[int, float]], d2: Dict[str, Union[int, float]], + tolerance_places: Optional[int] = None) -> None: + self.assertEqual(d1.keys(), d2.keys()) + tol = tolerance_places if tolerance_places is not None else self.TOLERANCE_PLACES + for k in d1: + v1, v2 = d1[k], d2[k] + self.assertAlmostEqual(v1, v2, places=tol) + + +class KFoldStatsConsistencyTests(KFoldCATTests): + + def test_mct_export_valid(self): + self.assertIsMCTExport(self.mct_export) + + def test_stats_consistent(self): + stats = self.cat._print_stats(self.mct_export, do_print=False) + for name, stats1, stats2 in zip(self._names, self.reg_stats, stats): + with self.subTest(name): + # NOTE: These should be EXACTLY equal since the two calls run on + # identical input and nothing should differ between them + self.assertEqual(stats1, stats2) + + +class KFoldMetricsTests(KFoldCATTests): + SPLIT_TYPE = kfold.SplitType.DOCUMENTS + + def test_metrics_1_fold_same_as_normal(self): + stats = kfold.get_k_fold_stats(self.cat, self.mct_export, k=1, + split_type=self.SPLIT_TYPE) + for name, reg, folds1 in zip(self._names, self.reg_stats, stats): + with self.subTest(name): + if name != 'examples': + # NOTE: These may not be exactly equal due to floating point errors + self.assertDictsAlmostEqual(reg, folds1) + else: + self.assertEqual(reg, folds1) + + +class KFoldPerAnnsMetricsTests(KFoldMetricsTests): + SPLIT_TYPE = kfold.SplitType.ANNOTATIONS + + +class KFoldWeightedDocsMetricsTests(KFoldMetricsTests): + SPLIT_TYPE = kfold.SplitType.DOCUMENTS_WEIGHTED + + +class KFoldDuplicatedTests(KFoldCATTests): + COPIES = 3 + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.docs_in_orig = kfold.count_all_docs(cls.mct_export) + cls.anns_in_orig = kfold.count_all_annotations(cls.mct_export) + cls.data_copied: kfold.MedCATTrainerExport = deepcopy(cls.mct_export) + for project in cls.data_copied['projects']: + documents_list = project['documents'] + copies = documents_list + [ + {k: v if k != 'name' else f"{v}_cp_{nr}" for k, v in doc.items()} for nr in range(cls.COPIES - 1) + for doc in documents_list + ] + project['documents'] = copies + cls.docs_in_copy = kfold.count_all_docs(cls.data_copied) + cls.anns_in_copy = kfold.count_all_annotations(cls.data_copied) + cls.stats_copied = kfold.get_k_fold_stats(cls.cat, cls.data_copied, k=cls.COPIES) + cls.stats_copied_2 = kfold.get_k_fold_stats(cls.cat, cls.data_copied, k=cls.COPIES) + + # some stats with a real model/data will differ slightly, e.g. 0.99 vs 0.9747 + # in such a case, lower this tolerance to 1 or so + _stats_consistency_tolerance = 8 + + def test_stats_consistent(self): + for name, one, two in zip(self._names, self.stats_copied, self.stats_copied_2): + with self.subTest(name): + if name == 'examples': + # examples are hard + # sometimes they differ by quite a lot + for etype in one: + ev1, ev2 = one[etype], two[etype] + with self.subTest(f"{name}-{etype}"): + self.assertEqual(ev1.keys(), ev2.keys()) + for cui in ev1: + per_cui_examples1 = ev1[cui] + per_cui_examples2 = ev2[cui] + with self.subTest(f"{name}-{etype}-{cui}-[{self.cat.cdb.cui2preferred_name.get(cui, cui)}]"): + self.assertEqual(len(per_cui_examples1), len(per_cui_examples2), "INCORRECT NUMBER OF 
ITEMS") + for ex1, ex2 in zip(per_cui_examples1, per_cui_examples2): + self.assertDictsAlmostEqual(ex1, ex2, tolerance_places=self._stats_consistency_tolerance) + continue + self.assertEqual(one, two) + + def test_copy_has_correct_number_documents(self): + self.assertEqual(self.COPIES * self.docs_in_orig, self.docs_in_copy) + + def test_copy_has_correct_number_annotations(self): + self.assertEqual(self.COPIES * self.anns_in_orig, self.anns_in_copy) + + def test_3_fold_identical_folds(self): + folds = kfold.get_fold_creator(self.data_copied, nr_of_folds=self.COPIES, + split_type=kfold.SplitType.DOCUMENTS).create_folds() + self.assertEqual(len(folds), self.COPIES) + for nr, fold in enumerate(folds): + with self.subTest(f"Fold-{nr}"): + # if they're all equal to original, they're eqaul to each other + self.assertEqual( + nullify_doc_names_proj_ids(fold), + nullify_doc_names_proj_ids(self.mct_export) + ) + + def test_metrics_3_fold(self): + stats_simple = self.reg_stats + for name, old, new in zip(self._names, stats_simple, self.stats_copied): + if name == 'examples': + continue + # with self.subTest(name): + if name in ("fps", "fns", "tps", "counts"): + # count should be triples + pass + if name in ("prec", "rec", "f1"): + # these should average to the same ?? + all_keys = old.keys() | new.keys() + for cui in all_keys: + cuiname = self.cat.cdb.cui2preferred_name.get(cui, cui) + with self.subTest(f"{name}-{cui} [{cuiname}]"): + self.assertIn(cui, old.keys(), f"CUI '{cui}' ({cuiname}) not in old") + self.assertIn(cui, new.keys(), f"CUI '{cui}' ({cuiname}) not in new") + v1, v2 = old[cui], new[cui] + self.assertEqual(v1, v2, f"Values not equal for {cui} ({self.cat.cdb.cui2preferred_name.get(cui, cui)})") diff --git a/tests/stats/test_mctexport.py b/tests/stats/test_mctexport.py new file mode 100644 index 000000000..8ef11f556 --- /dev/null +++ b/tests/stats/test_mctexport.py @@ -0,0 +1,38 @@ +import os +import json + +from medcat.stats import mctexport + +import unittest + +from .helpers import MCTExportPydanticModel + + +class MCTExportIterationTests(unittest.TestCase): + EXPORT_PATH = os.path.join(os.path.dirname(__file__), "..", + "resources", "medcat_trainer_export.json") + EXPECTED_DOCS = 27 + EXPECTED_ANNS = 435 + + @classmethod + def setUpClass(cls) -> None: + with open(cls.EXPORT_PATH) as f: + cls.mct_export: mctexport.MedCATTrainerExport = json.load(f) + + def test_conforms_to_template(self): + # NOTE: This uses pydantic to make sure that the MedCATTrainerExport + # type matches the actual export format + model_instance = MCTExportPydanticModel(**self.mct_export) + self.assertIsInstance(model_instance, MCTExportPydanticModel) + + def test_iterates_over_all_docs(self): + self.assertEqual(mctexport.count_all_docs(self.mct_export), self.EXPECTED_DOCS) + + def test_iterates_over_all_anns(self): + self.assertEqual(mctexport.count_all_annotations(self.mct_export), self.EXPECTED_ANNS) + + def test_gets_correct_nr_of_annotations_per_doc(self): + for project in self.mct_export['projects']: + for doc in project["documents"]: + with self.subTest(f"Proj-{project['name']} ({project['id']})-{doc['name']} ({doc['id']})"): + self.assertEqual(mctexport.get_nr_of_annotations(doc), len(doc["annotations"])) diff --git a/tests/test_cat.py b/tests/test_cat.py index ce1b62d98..780039473 100644 --- a/tests/test_cat.py +++ b/tests/test_cat.py @@ -301,9 +301,9 @@ def _test_train_superivsed(self, temp_file: str): data_path = self.SUPERVISED_TRAINING_JSON ckpt_dir_path = temp_file checkpoint = 
Checkpoint(dir_path=ckpt_dir_path, steps=1, max_to_keep=sys.maxsize) - fp, fn, tp, p, r, f1, cui_counts, examples = self.undertest.train_supervised(data_path, - checkpoint=checkpoint, - nepochs=nepochs) + fp, fn, tp, p, r, f1, cui_counts, examples = self.undertest.train_supervised_from_json(data_path, + checkpoint=checkpoint, + nepochs=nepochs) checkpoints = [f for f in os.listdir(ckpt_dir_path) if "checkpoint-" in f] self.assertEqual({}, fp) self.assertEqual({}, fn) @@ -328,13 +328,11 @@ def _test_resume_supervised_training(self, temp_file: str): data_path = os.path.join(os.path.dirname(__file__), "resources", "medcat_trainer_export.json") ckpt_dir_path = temp_file checkpoint = Checkpoint(dir_path=ckpt_dir_path, steps=1, max_to_keep=sys.maxsize) - self.undertest.train_supervised(data_path, - checkpoint=checkpoint, - nepochs=nepochs_train) - fp, fn, tp, p, r, f1, cui_counts, examples = self.undertest.train_supervised(data_path, - checkpoint=checkpoint, - nepochs=nepochs_train+nepochs_retrain, - is_resumed=True) + self.undertest.train_supervised_from_json(data_path, + checkpoint=checkpoint, + nepochs=nepochs_train) + fp, fn, tp, p, r, f1, cui_counts, examples = self.undertest.train_supervised_from_json( + data_path, checkpoint=checkpoint, nepochs=nepochs_train+nepochs_retrain, is_resumed=True) checkpoints = [f for f in os.listdir(ckpt_dir_path) if "checkpoint-" in f] self.assertEqual({}, fp) self.assertEqual({}, fn) @@ -351,15 +349,15 @@ def _test_resume_supervised_training(self, temp_file: str): def test_train_supervised_does_not_retain_MCT_filters_default(self, extra_cui_filter=None): data_path = os.path.join(os.path.dirname(__file__), "resources", "medcat_trainer_export_filtered.json") before = str(self.undertest.config.linking.filters) - self.undertest.train_supervised(data_path, nepochs=1, use_filters=True, extra_cui_filter=extra_cui_filter) + self.undertest.train_supervised_from_json(data_path, nepochs=1, use_filters=True, extra_cui_filter=extra_cui_filter) after = str(self.undertest.config.linking.filters) self.assertEqual(before, after) def test_train_supervised_can_retain_MCT_filters(self, extra_cui_filter=None, retain_extra_cui_filter=False): data_path = os.path.join(os.path.dirname(__file__), "resources", "medcat_trainer_export_filtered.json") before = str(self.undertest.config.linking.filters) - self.undertest.train_supervised(data_path, nepochs=1, use_filters=True, retain_filters=True, - extra_cui_filter=extra_cui_filter, retain_extra_cui_filter=retain_extra_cui_filter) + self.undertest.train_supervised_from_json(data_path, nepochs=1, use_filters=True, retain_filters=True, + extra_cui_filter=extra_cui_filter, retain_extra_cui_filter=retain_extra_cui_filter) after = str(self.undertest.config.linking.filters) self.assertNotEqual(before, after) with open(data_path, 'r') as f: @@ -701,7 +699,7 @@ def _get_meta_cat(meta_cat_dir): config=config) os.makedirs(meta_cat_dir, exist_ok=True) json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", "mct_export_for_meta_cat_test.json") - meta_cat.train(json_path, save_dir_path=meta_cat_dir) + meta_cat.train_from_json(json_path, save_dir_path=meta_cat_dir) return meta_cat @@ -712,7 +710,7 @@ class TestLoadingOldWeights(unittest.TestCase): @classmethod def setUpClass(cls) -> None: cls.cdb = CDB.load(cls.cdb_path) - cls.wf = cls.cdb.config.linking.weighted_average_function + cls.wf = cls.cdb.weighted_average_function def test_can_call_weights(self): res = self.wf(step=1) diff --git a/tests/test_cdb_maker.py 
b/tests/test_cdb_maker.py index f84e47b15..f454ebe7d 100644 --- a/tests/test_cdb_maker.py +++ b/tests/test_cdb_maker.py @@ -132,6 +132,24 @@ def setUpClass(cls): def tearDownClass(cls) -> None: cls.maker.destroy_pipe() + # NOTE: The following tests are state-dependent. That is to say, + # if the order in which they're executed changes, they may fail. + # They currently rely on the fact that test methods are executed + # in alphabetical order. But this is overall not good test design, + # since the failure of one unit could lead to the failure of another + # in unexpected ways (since there's an expectation on the state). + # + # e.g., if I run: + # python -m unittest\ + # tests.test_cdb_maker.B_CDBMakerEditTests.test_bd_addition_of_context_vector_positive\ + # tests.test_cdb_maker.B_CDBMakerEditTests.test_bc_filter_by_cui\ + # tests.test_cdb_maker.B_CDBMakerEditTests.test_bb_removal_of_name\ + # tests.test_cdb_maker.B_CDBMakerEditTests.test_ba_addition_of_new_name + # Then there will be failures in `test_ba_addition_of_new_name` and `test_bb_removal_of_name` + # due to the changes in state. + # + # Though to make it clear, in the standard configuration the tests run in the + # "correct" order and are successful. def test_ba_addition_of_new_name(self): self.cdb.add_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', self.maker.pipe.spacy_nlp, {}, self.config), name_status='P', full_build=True) self.assertEqual(len(self.cdb.name2cuis), 6, "Should equal 6") @@ -142,7 +160,7 @@ def test_ba_addition_of_new_name(self): self.assertIn('my~:~new~name~.', self.cdb.name2cuis2status) def test_bb_removal_of_name(self): - self.cdb.remove_names(cui='C0000239', names=prepare_name('MY: new,-_! Name.', self.maker.pipe.spacy_nlp, {}, self.config)) + self.cdb._remove_names(cui='C0000239', names=prepare_name('MY: new,-_! 
Name.', self.maker.pipe.spacy_nlp, {}, self.config)) self.assertEqual(len(self.cdb.name2cuis), 5, "Should equal 5") self.assertNotIn('my:newname.', self.cdb.name2cuis2status) diff --git a/tests/test_config.py b/tests/test_config.py index ce6ed76eb..bfd440a78 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,6 +2,7 @@ import pickle import tempfile from medcat.config import Config, MixingConfig, VersionInfo, General, LinkingFilters +from medcat.config import UseOfOldConfigOptionException, Linking from pydantic import ValidationError import os @@ -208,6 +209,13 @@ def test_config_hash_recalc_same_changed(self): h2 = config.get_hash() self.assertEqual(h1, h2) + def test_can_save_load(self): + config = Config() + with tempfile.NamedTemporaryFile() as file: + config.save(file.name) + config2 = Config.load(file.name) + self.assertEqual(config, config2) + class ConfigLinkingFiltersTests(unittest.TestCase): @@ -228,5 +236,36 @@ def test_not_allow_empty_dict_for_cuis_exclude(self): LinkingFilters(cuis_exclude={}) +class BackwardsCompatibilityTests(unittest.TestCase): + + def setUp(self) -> None: + self.config = Config() + + def test_use_weighted_average_function_identifier_nice_error(self): + with self.assertRaises(UseOfOldConfigOptionException): + self.config.linking.weighted_average_function(0) + + def test_use_weighted_average_function_dict_nice_error(self): + with self.assertRaises(UseOfOldConfigOptionException): + self.config.linking['weighted_average_function'](0) + + +class BackwardsCompatibilityWafPayloadTests(unittest.TestCase): + arg = 'weighted_average_function' + + @classmethod + def setUpClass(cls) -> None: + cls.config = Config() + with cls.assertRaises(cls, UseOfOldConfigOptionException) as cls.context: + cls.config.linking.weighted_average_function(0) + cls.raised = cls.context.exception + + def test_exception_has_correct_conf_type(self): + self.assertIs(self.raised.conf_type, Linking) + + def test_exception_has_correct_arg(self): + self.assertEqual(self.raised.arg_name, self.arg) + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_meta_cat.py b/tests/test_meta_cat.py index 8cd444668..ead082c0b 100644 --- a/tests/test_meta_cat.py +++ b/tests/test_meta_cat.py @@ -10,6 +10,7 @@ import spacy from spacy.tokens import Span + class MetaCATTests(unittest.TestCase): @classmethod @@ -17,7 +18,7 @@ def setUpClass(cls) -> None: tokenizer = TokenizerWrapperBERT(AutoTokenizer.from_pretrained('prajjwal1/bert-tiny')) config = ConfigMetaCAT() config.general['category_name'] = 'Status' - config.train['nepochs'] = 1 + config.train['nepochs'] = 2 config.model['input_size'] = 100 cls.meta_cat: MetaCAT = MetaCAT(tokenizer=tokenizer, embeddings=None, config=config) @@ -29,14 +30,16 @@ def tearDown(self) -> None: shutil.rmtree(self.tmp_dir) def test_train(self): - json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources', 'mct_export_for_meta_cat_test.json') - results = self.meta_cat.train(json_path, save_dir_path=self.tmp_dir) - - self.assertEqual(results['report']['weighted avg']['f1-score'], 1.0) + json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources', + 'mct_export_for_meta_cat_test.json') + results = self.meta_cat.train_from_json(json_path, save_dir_path=self.tmp_dir) + if self.meta_cat.config.model.phase_number != 1: + self.assertEqual(results['report']['weighted avg']['f1-score'], 1.0) def test_save_load(self): - json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources', 
'mct_export_for_meta_cat_test.json') - self.meta_cat.train(json_path, save_dir_path=self.tmp_dir) + json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources', + 'mct_export_for_meta_cat_test.json') + self.meta_cat.train_from_json(json_path, save_dir_path=self.tmp_dir) self.meta_cat.save(self.tmp_dir) n_meta_cat = MetaCAT.load(self.tmp_dir) @@ -53,17 +56,18 @@ def _prepare_doc_w_spangroup(self, spangroup_name: str): Span.set_extension('meta_anns', default=None, force=True) nlp = spacy.blank("en") doc = nlp("Pt has diabetes and copd.") - span_0 = doc.char_span(7,15, label="diabetes") + span_0 = doc.char_span(7, 15, label="diabetes") assert span_0.text == 'diabetes' - span_1 = doc.char_span(20,24, label="copd") + span_1 = doc.char_span(20, 24, label="copd") assert span_1.text == 'copd' doc.spans[spangroup_name] = [span_0, span_1] return doc def test_predict_spangroup(self): - json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources', 'mct_export_for_meta_cat_test.json') - self.meta_cat.train(json_path, save_dir_path=self.tmp_dir) + json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources', + 'mct_export_for_meta_cat_test.json') + self.meta_cat.train_from_json(json_path, save_dir_path=self.tmp_dir) self.meta_cat.save(self.tmp_dir) n_meta_cat = MetaCAT.load(self.tmp_dir) @@ -90,5 +94,29 @@ def test_predict_spangroup(self): n_meta_cat.config.general.span_group = None +class MetaCATBertTest(MetaCATTests): + @classmethod + def setUpClass(cls) -> None: + tokenizer = TokenizerWrapperBERT(AutoTokenizer.from_pretrained('prajjwal1/bert-tiny')) + config = ConfigMetaCAT() + config.general['category_name'] = 'Status' + config.train['nepochs'] = 2 + config.model['input_size'] = 100 + config.train['batch_size'] = 64 + config.model['model_name'] = 'bert' + + cls.meta_cat: MetaCAT = MetaCAT(tokenizer=tokenizer, embeddings=None, config=config) + cls.tmp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp") + os.makedirs(cls.tmp_dir, exist_ok=True) + + def test_two_phase(self): + self.meta_cat.config.model['phase_number'] = 1 + self.test_train() + self.meta_cat.config.model['phase_number'] = 2 + self.test_train() + + self.meta_cat.config.model['phase_number'] = 0 + + if __name__ == '__main__': unittest.main() diff --git a/tests/utils/saving/test_envsnapshot.py b/tests/utils/saving/test_envsnapshot.py new file mode 100644 index 000000000..16bee1ffb --- /dev/null +++ b/tests/utils/saving/test_envsnapshot.py @@ -0,0 +1,105 @@ +from typing import Any +import platform +import os +import tempfile +import json +import zipfile + +from medcat.cat import CAT +from medcat.utils.saving import envsnapshot + +import unittest + + +def list_zip_contents(zip_file_path): + with zipfile.ZipFile(zip_file_path, 'r') as zip_ref: + return zip_ref.namelist() + + +class DirectDependenciesTests(unittest.TestCase): + + def setUp(self) -> None: + self.direct_deps = envsnapshot.get_direct_dependencies() + + def test_nonempty(self): + self.assertTrue(self.direct_deps) + + def test_does_not_contain_versions(self, version_starters: str = '<=>~'): + for dep in self.direct_deps: + for vs in version_starters: + with self.subTest(f"DEP '{dep}' check for '{vs}'"): + self.assertNotIn(vs, dep) + + def test_deps_are_installed_packages(self): + for dep in self.direct_deps: + with self.subTest(f"Has '{dep}'"): + envsnapshot.pkg_resources.require(dep) + + +class EnvSnapshotAloneTests(unittest.TestCase): + + def setUp(self) -> None: + self.env_info = 
envsnapshot.get_environment_info() + + def test_info_is_dict(self): + self.assertIsInstance(self.env_info, dict) + + def test_info_is_not_empty(self): + self.assertTrue(self.env_info) + + def assert_has_target(self, target: str, expected: Any): + self.assertIn(target, self.env_info) + actual = self.env_info[target] + self.assertEqual(actual, expected) + + def test_has_os(self): + self.assert_has_target("os", platform.platform()) + + def test_has_py_ver(self): + self.assert_has_target("python_version", platform.python_version()) + + def test_has_cpu_arch(self): + self.assert_has_target("cpu_architecture", platform.machine()) + + def test_has_dependencies(self, name: str = "dependencies"): + # NOTE: just making sure it's a non-empty list + self.assertIn(name, self.env_info) + deps = self.env_info[name] + self.assertTrue(deps) + + def test_all_direct_dependencies_are_installed(self): + deps = self.env_info['dependencies'] + direct_deps = envsnapshot.get_direct_dependencies() + self.assertEqual(len(deps), len(direct_deps)) + + +CAT_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "..", "examples") +ENV_SNAPSHOT_FILE_NAME = envsnapshot.ENV_SNAPSHOT_FILE_NAME + + +class EnvSnapshotInCATTests(unittest.TestCase): + expected_env = envsnapshot.get_environment_info() + + @classmethod + def setUpClass(cls) -> None: + cls.cat = CAT.load_model_pack(CAT_PATH) + cls._temp_dir = tempfile.TemporaryDirectory() + mpn = cls.cat.create_model_pack(cls._temp_dir.name) + cls.cat_folder = os.path.join(cls._temp_dir.name, mpn) + cls.environ_file_path = os.path.join(cls.cat_folder, ENV_SNAPSHOT_FILE_NAME) + + def test_has_environment(self): + self.assertTrue(os.path.exists(self.environ_file_path)) + + def test_environ_saved(self): + with open(self.environ_file_path) as f: + saved_info: dict = json.load(f) + self.assertEqual(saved_info.keys(), self.expected_env.keys()) + for k in saved_info: + with self.subTest(k): + v1, v2 = saved_info[k], self.expected_env[k] + self.assertEqual(v1, v2) + + def test_zip_has_env_snapshot(self): + filenames = list_zip_contents(self.cat_folder + ".zip") + self.assertIn(ENV_SNAPSHOT_FILE_NAME, filenames) diff --git a/tests/utils/saving/test_serialization.py b/tests/utils/saving/test_serialization.py index c2c44da16..cb26312f0 100644 --- a/tests/utils/saving/test_serialization.py +++ b/tests/utils/saving/test_serialization.py @@ -10,6 +10,7 @@ from medcat.vocab import Vocab from medcat.utils.saving.serializer import JsonSetSerializer, CDBSerializer, SPECIALITY_NAMES, ONE2MANY +from medcat.utils.saving.envsnapshot import ENV_SNAPSHOT_FILE_NAME import medcat.utils.saving.coding as _ @@ -60,6 +61,7 @@ class ModelCreationTests(unittest.TestCase): json_model_pack = tempfile.TemporaryDirectory() EXAMPLES = os.path.join(os.path.dirname( os.path.realpath(__file__)), "..", "..", "..", "examples") + EXCEPTIONAL_JSONS = ['model_card.json', ENV_SNAPSHOT_FILE_NAME] @classmethod def setUpClass(cls) -> None: @@ -95,7 +97,7 @@ def test_dill_to_json(self): SPECIALITY_NAMES) - len(ONE2MANY)) for json in jsons: with self.subTest(f'JSON {json}'): - if json.endswith('model_card.json'): + if any(json.endswith(exception) for exception in self.EXCEPTIONAL_JSONS): continue # ignore model card here if any(name in json for name in ONE2MANY): # ignore cui2many and name2many @@ -117,10 +119,6 @@ def test_round_trip(self): # The spacy model has full path in the loaded model, thus won't be equal cat.config.general.spacy_model = os.path.basename( cat.config.general.spacy_model) - # There can also be issues 
with loading the config.linking.weighted_average_function from file - # This should be fixed with newer models, - # but the example model is older, so has the older functionalitys - cat.config.linking.weighted_average_function = self.undertest.config.linking.weighted_average_function self.assertEqual(cat.config.asdict(), self.undertest.config.asdict()) self.assertEqual(cat.cdb.config, self.undertest.cdb.config) self.assertEqual(len(cat.vocab.vocab), len(self.undertest.vocab.vocab)) diff --git a/tests/utils/test_cdb_state.py b/tests/utils/test_cdb_state.py new file mode 100644 index 000000000..068af128b --- /dev/null +++ b/tests/utils/test_cdb_state.py @@ -0,0 +1,113 @@ +import unittest +import os +from unittest import mock +from typing import Callable, Any, Dict +import tempfile + +from medcat.utils.cdb_state import captured_state_cdb, CDBState, copy_cdb_state +from medcat.cdb import CDB +from medcat.vocab import Vocab +from medcat.cat import CAT + + +class StateTests(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + cls.cdb = CDB.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "examples", "cdb.dat")) + cls.vocab = Vocab.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "examples", "vocab.dat")) + cls.vocab.make_unigram_table() + cls.cdb.config.general.spacy_model = "en_core_web_md" + cls.meta_cat_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp") + cls.undertest = CAT(cdb=cls.cdb, config=cls.cdb.config, vocab=cls.vocab, meta_cats=[]) + cls.initial_state = copy_cdb_state(cls.cdb) + + @classmethod + def _set_info(cls, k: str, v: Any, info_dict: Dict): + info_dict[k] = (len(v), len(str(v))) + + @classmethod + def do_smth_for_each_state_var(cls, cdb: CDB, callback: Callable[[str, Any], None]) -> None: + for k in CDBState.__annotations__: + v = getattr(cdb, k) + callback(k, v) + + +class StateSavedTests(StateTests): + on_disk = False + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + # capture state + with captured_state_cdb(cls.cdb, save_state_to_disk=cls.on_disk): + # clear state + cls.do_smth_for_each_state_var(cls.cdb, lambda k, v: v.clear()) + cls.cleared_state = copy_cdb_state(cls.cdb) + # save the state after the context exits - it should equal the initial state + cls.restored_state = copy_cdb_state(cls.cdb) + + def test_state_saved(self): + nr_of_targets = len(CDBState.__annotations__) + self.assertGreater(nr_of_targets, 0) + self.assertEqual(len(self.initial_state), nr_of_targets) + self.assertEqual(len(self.cleared_state), nr_of_targets) + self.assertEqual(len(self.restored_state), nr_of_targets) + + def test_clearing_worked(self): + self.assertNotEqual(self.initial_state, self.cleared_state) + for k, v in self.cleared_state.items(): + with self.subTest(k): + # length is 0 + self.assertFalse(v) + + def test_state_restored(self): + self.assertEqual(self.initial_state, self.restored_state) + + +class StateSavedOnDiskTests(StateSavedTests): + on_disk = True + _named_temporary_file = tempfile.NamedTemporaryFile + + @classmethod + def saved_name_temp_file(cls): + tf = cls._named_temporary_file() + cls.temp_file_name = tf.name + return tf + + @classmethod + def setUpClass(cls) -> None: + with mock.patch("builtins.open", side_effect=open) as cls.popen: + with mock.patch("tempfile.NamedTemporaryFile", side_effect=cls.saved_name_temp_file) as cls.pntf: + return super().setUpClass() + + def test_temp_file_called(self): + self.pntf.assert_called_once() + + def test_saved_on_disk(self): + self.popen.assert_called()
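+        # the snapshot is first written to the temp file ('wb') and then read back from it ('rb'), hence at least two open() calls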
self.assertGreaterEqual(self.popen.call_count, 2) + self.popen.assert_has_calls([mock.call(self.temp_file_name, 'wb'), + mock.call(self.temp_file_name, 'rb')]) + + +class StateWithTrainingTests(StateTests): + SUPERVISED_TRAINING_JSON = os.path.join(os.path.dirname(__file__), "..", "resources", "medcat_trainer_export.json") + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + with captured_state_cdb(cls.cdb): + # do training + cls.undertest.train_supervised_from_json(cls.SUPERVISED_TRAINING_JSON) + cls.after_train_state = copy_cdb_state(cls.cdb) + cls.restored_state = copy_cdb_state(cls.cdb) + + +class StateRestoredAfterTrain(StateWithTrainingTests): + + def test_train_state_changed(self): + self.assertNotEqual(self.initial_state, self.after_train_state) + + def test_restored_state_same(self): + self.assertDictEqual(self.initial_state, self.restored_state) diff --git a/tests/utils/test_config_utils.py b/tests/utils/test_config_utils.py new file mode 100644 index 000000000..d1a7262e7 --- /dev/null +++ b/tests/utils/test_config_utils.py @@ -0,0 +1,121 @@ +from medcat.config import Config +from medcat.utils.saving.coding import default_hook, CustomDelegatingEncoder +from medcat.utils import config_utils +from medcat import config as main_config +from medcat import config_meta_cat +from medcat import config_transformers_ner +from medcat import config_rel_cat +import json +import os + +import unittest + +OLD_STYLE_DICT = {'py/object': 'medcat.config.VersionInfo', + 'py/state': { + '__dict__': { + 'history': ['0c0de303b6dc0020',], + 'meta_cats': [], + 'cdb_info': { + 'Number of concepts': 785910, + 'Number of names': 2480049, + 'Number of concepts that received training': 378746, + 'Number of seen training examples in total': 1863973060, + 'Average training examples per concept': { + 'py/reduce': [{'py/function': 'numpy.core.multiarray.scalar'},] + } + }, + 'performance': {'ner': {}, 'meta': {}}, + 'description': 'No description', + 'id': 'ff4f4e00bc97de58', + 'last_modified': '26 April 2024', + 'location': None, + 'ontology': ['ONTOLOGY1'], + 'medcat_version': '1.10.2' + }, + '__fields_set__': { + 'py/set': ['id', 'ontology', 'description', 'history', + 'location', 'medcat_version', 'last_modified', + 'meta_cats', 'cdb_info', 'performance'] + }, + '__private_attribute_values__': {} + } + } + + +NEW_STYLE_DICT = json.loads(json.dumps(Config().asdict(), cls=CustomDelegatingEncoder.def_inst), + object_hook=default_hook) + + +class ConfigUtilsTests(unittest.TestCase): + + def test_identifies_old_style_dict(self): + self.assertTrue(config_utils.is_old_type_config_dict(OLD_STYLE_DICT)) + + def test_identifies_new_style_dict(self): + self.assertFalse(config_utils.is_old_type_config_dict(NEW_STYLE_DICT)) + + +class OldFormatJsonTests(unittest.TestCase): + + def assert_knows_old_format(self, file_path: str): + with open(file_path) as f: + d = json.load(f) + self.assertTrue(config_utils.is_old_type_config_dict(d)) + + +class OldConfigLoadTests(OldFormatJsonTests): + JSON_PICKLE_FILE_PATH = os.path.join( + os.path.dirname(__file__), "..", "resources", "jsonpickle_config.json" + ) + EXPECTED_VERSION_HISTORY = ['0c0de303b6dc0020',] + + def test_knows_is_old_format(self): + self.assert_knows_old_format(self.JSON_PICKLE_FILE_PATH) + + def test_loads_old_style_correctly(self): + cnf: main_config.Config = main_config.Config.load(self.JSON_PICKLE_FILE_PATH) + self.assertEqual(cnf.version.history, self.EXPECTED_VERSION_HISTORY) + + +class MetaCATConfigTests(OldFormatJsonTests): + 
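+    # Base test: load an old-style (jsonpickle) config resource and check that one known value is read back correctly; subclasses override the resource path, config class and expected target.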
META_CAT_OLD_PATH = os.path.join( + os.path.dirname(__file__), "..", "resources", "jsonpickle_meta_cat_config.json" + ) + EXPECTED_TARGET = -100 + TARGET_CLASS = config_meta_cat.ConfigMetaCAT + + @classmethod + def get_target(cls, cnf): + return cnf.general.seed + + def test_knows_is_old_format(self): + self.assert_knows_old_format(self.META_CAT_OLD_PATH) + + def test_can_load_old_format_correctly(self): + cnf = self.TARGET_CLASS.load(self.META_CAT_OLD_PATH) + self.assertIsInstance(cnf, self.TARGET_CLASS) + self.assertEqual(self.get_target(cnf), self.EXPECTED_TARGET) + + +class TNERCATConfigTests(MetaCATConfigTests): + META_CAT_OLD_PATH = os.path.join( + os.path.dirname(__file__), "..", "resources", "jsonpickle_tner_config.json" + ) + EXPECTED_TARGET = -100 + TARGET_CLASS = config_transformers_ner.ConfigTransformersNER + + @classmethod + def get_target(cls, cnf): + return cnf.general.pipe_batch_size_in_chars + + +class RelCATConfigTests(MetaCATConfigTests): + META_CAT_OLD_PATH = os.path.join( + os.path.dirname(__file__), "..", "resources", "jsonpickle_rel_cat_config.json" + ) + EXPECTED_TARGET = 100_000 + TARGET_CLASS = config_rel_cat.ConfigRelCAT + + @classmethod + def get_target(cls, cnf): + return cnf.train.lr diff --git a/webapp/.gitignore b/webapp/.gitignore deleted file mode 100644 index fc6ea2e67..000000000 --- a/webapp/.gitignore +++ /dev/null @@ -1,6 +0,0 @@ -webapp/data/* -!webapp/data/.keep -webapp/db/* -!webapp/db/.keep -webapp/models/* -!webapp/models/.keep \ No newline at end of file diff --git a/webapp/README.md b/webapp/README.md deleted file mode 100644 index 128741d05..000000000 --- a/webapp/README.md +++ /dev/null @@ -1 +0,0 @@ -This is a demo application for MedCAT, please note that it was made to be as unreadable as possible - I appologize to anyone that has to use this, it was not done on on purpose. 
diff --git a/webapp/docker-compose.yml b/webapp/docker-compose.yml deleted file mode 100644 index 60a392990..000000000 --- a/webapp/docker-compose.yml +++ /dev/null @@ -1,26 +0,0 @@ -version: '3.4' - -services: - medcatweb: - build: - network: host - context: ./webapp - command: > - bash -c "/etc/init.d/cron start && - python /webapp/manage.py runserver 0.0.0.0:8000" - volumes: - - ./webapp/data:/webapp/data - - ./webapp/db:/webapp/db - - ./webapp/models:/webapp/models - - ./envs/env_db_backup:/etc/environment - - medcat_data:/medcat_data - ports: - - "80:8000" - env_file: - - ./envs/env_medmen - - ./envs/env_db_backup - tty: true - -volumes: - medcat_data: - driver: local diff --git a/webapp/envs/env_db_backup b/webapp/envs/env_db_backup deleted file mode 100644 index 6071abf41..000000000 --- a/webapp/envs/env_db_backup +++ /dev/null @@ -1,8 +0,0 @@ -DB_BACKUP_ON_S3=False -DB_BACKUP_LOCATION=demo-db-backup/ -DB_BACKUP_EVERY_MINS=720 -DB_BACKUP_RETRY_BACKOFF_MINS=5 -ACCESS_KEY= -SECRET_KEY= -BUCKET_NAME= -DELETE_LOGS_OLDER_THAN=7 \ No newline at end of file diff --git a/webapp/envs/env_medmen b/webapp/envs/env_medmen deleted file mode 100644 index 9952ef961..000000000 --- a/webapp/envs/env_medmen +++ /dev/null @@ -1 +0,0 @@ -MODEL_PACK_PATH=// diff --git a/webapp/webapp/.dockerignore b/webapp/webapp/.dockerignore deleted file mode 100644 index 285877ca8..000000000 --- a/webapp/webapp/.dockerignore +++ /dev/null @@ -1,2 +0,0 @@ -data -models \ No newline at end of file diff --git a/webapp/webapp/Dockerfile b/webapp/webapp/Dockerfile deleted file mode 100644 index 21d19078f..000000000 --- a/webapp/webapp/Dockerfile +++ /dev/null @@ -1,37 +0,0 @@ -FROM python:3.7 - -# Create the required folders -RUN mkdir -p /webapp/models - -# Copy everything -COPY . /webapp - -ENV VOCAB_URL=https://medcat.rosalind.kcl.ac.uk/media/vocab.dat -ENV CDB_URL=https://medcat.rosalind.kcl.ac.uk/media/cdb-medmen-v1.dat - -ENV CDB_PATH=/webapp/models/cdb.dat -ENV VOCAB_PATH=/webapp/models/vocab.dat - -# Create the data directory -RUN mkdir -p /medcat_data - -# Set the pythonpath -WORKDIR /webapp - -RUN pip install -r requirements.txt - -# Get the spacy model -RUN python -m spacy download en_core_web_md - -# Build the db -RUN python manage.py makemigrations && \ - python manage.py makemigrations demo && \ - python manage.py migrate && \ - python manage.py migrate demo && \ - python manage.py collectstatic --noinput - -# Create the db backup cron job -RUN apt-get update && apt-get install -y --no-install-recommends apt-utils cron sqlite3 libsqlite3-dev -COPY etc/cron.d/db-backup-cron /etc/cron.d/db-backup-cron -RUN chmod 0644 /etc/cron.d/db-backup-cron -RUN crontab /etc/cron.d/db-backup-cron diff --git a/webapp/webapp/data/.keep b/webapp/webapp/data/.keep deleted file mode 100644 index e69de29bb..000000000 diff --git a/webapp/webapp/db/.keep b/webapp/webapp/db/.keep deleted file mode 100644 index e69de29bb..000000000 diff --git a/webapp/webapp/demo/admin.py b/webapp/webapp/demo/admin.py deleted file mode 100644 index 7cce1297b..000000000 --- a/webapp/webapp/demo/admin.py +++ /dev/null @@ -1,16 +0,0 @@ -from django.contrib import admin -from .models import * - -admin.site.register(Downloader) -admin.site.register(MedcatModel) - -def remove_text(modeladmin, request, queryset): - UploadedText.objects.all().delete() - -class UploadedTextAdmin(admin.ModelAdmin): - model = UploadedText - actions = [remove_text] - -# Register your models here. 
-admin.site.register(UploadedText, UploadedTextAdmin) - diff --git a/webapp/webapp/demo/apps.py b/webapp/webapp/demo/apps.py deleted file mode 100644 index 57920c332..000000000 --- a/webapp/webapp/demo/apps.py +++ /dev/null @@ -1,5 +0,0 @@ -from django.apps import AppConfig - - -class DemoConfig(AppConfig): - name = 'demo' diff --git a/webapp/webapp/demo/db_backup.py b/webapp/webapp/demo/db_backup.py deleted file mode 100644 index e637f249a..000000000 --- a/webapp/webapp/demo/db_backup.py +++ /dev/null @@ -1,20 +0,0 @@ -import os -from django.core import management -from django.conf import settings -from django_cron import CronJobBase, Schedule - - -class DbBackup(CronJobBase): - - RUN_EVERY_MINS = int(os.environ.get("DB_BACKUP_EVERY_MINS", 60 * 12)) - RETRY_AFTER_FAILURE_MINS = int(os.environ.get("DB_BACKUP_RETRY_BACKOFF_MINS", 5)) - - schedule = Schedule(run_every_mins=RUN_EVERY_MINS, retry_after_failure_mins=RETRY_AFTER_FAILURE_MINS) - code = "demo.db_backup.DbBackup" - - def __init__(self): - backup_location = settings.DBBACKUP_STORAGE_OPTIONS["location"] - os.makedirs(backup_location, exist_ok=True) - - def do(self): - management.call_command("dbbackup", "--noinput", "-z") diff --git a/webapp/webapp/demo/forms.py b/webapp/webapp/demo/forms.py deleted file mode 100644 index 100efc438..000000000 --- a/webapp/webapp/demo/forms.py +++ /dev/null @@ -1,48 +0,0 @@ -from email.policy import default -from django import forms -from .models import Downloader - - -class DownloaderForm(forms.ModelForm): - consent = forms.BooleanField(required=True, label=( - f"I consent to MedCAT collecting and storing my names, email, company" - f" or academic institution name, funder and project title, and use" - f" case description. I am aware that MedCAT has been funded through" - f" academic research grants, and therefore funding bodies require its" - f" support team to report wider impact and usage of produced works" - f" with the above information." 
- )) - - def __init__(self, models, *args, **kwargs): - super().__init__(*args, *kwargs) - self.fields["modelpack"] = forms.ChoiceField(label="Select a model for download", - choices=[( - model.model_name, - f"{model.model_display_name}{' (' + model.model_description + ')' if model.model_description else ''}" - ) for model in models], - widget=forms.RadioSelect()) - - class Meta: - model = Downloader - exclude = ['downloaded_file'] - fields = [ - "first_name", - "last_name", - "email", - "affiliation", - "funder", - "use_case", - ] - labels = { - "first_name": "First Name", - "last_name": "Last Name", - "email": "Email", - "affiliation": "Company or Academic Institution", - "funder": "Funder and Project Title (optional)", - "use_case": "Please describe your use case", - } - widgets = { - "affiliation": forms.TextInput(attrs={"size": 40}), - "funder": forms.TextInput(attrs={"size": 40}), - "use_case": forms.Textarea(attrs={"rows": 5, "cols": 40}), - } diff --git a/webapp/webapp/demo/migrations/0001_initial.py b/webapp/webapp/demo/migrations/0001_initial.py deleted file mode 100644 index 4d63843b9..000000000 --- a/webapp/webapp/demo/migrations/0001_initial.py +++ /dev/null @@ -1,22 +0,0 @@ -# Generated by Django 2.2.3 on 2019-09-17 11:43 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - initial = True - - dependencies = [ - ] - - operations = [ - migrations.CreateModel( - name='UploadedText', - fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('text', models.TextField(blank=True, default='')), - ('create_time', models.DateTimeField(auto_now_add=True)), - ], - ), - ] diff --git a/webapp/webapp/demo/migrations/0002_downloader_medcatmodel.py b/webapp/webapp/demo/migrations/0002_downloader_medcatmodel.py deleted file mode 100644 index d5ce70fd4..000000000 --- a/webapp/webapp/demo/migrations/0002_downloader_medcatmodel.py +++ /dev/null @@ -1,38 +0,0 @@ -# Generated by Django 3.2.11 on 2022-04-06 16:42 - -import django.core.files.storage -from django.db import migrations, models -import django.utils.timezone - - -class Migration(migrations.Migration): - - dependencies = [ - ('demo', '0001_initial'), - ] - - operations = [ - migrations.CreateModel( - name='Downloader', - fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('first_name', models.CharField(max_length=20)), - ('last_name', models.CharField(max_length=20)), - ('email', models.EmailField(max_length=50)), - ('affiliation', models.CharField(max_length=100)), - ('funder', models.CharField(blank=True, default='', max_length=100)), - ('use_case', models.TextField(max_length=200)), - ('downloaded_file', models.CharField(default=django.utils.timezone.now, max_length=100)), - ], - ), - migrations.CreateModel( - name='MedcatModel', - fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('model_name', models.CharField(max_length=20, unique=True)), - ('model_file', models.FileField(storage=django.core.files.storage.FileSystemStorage(location='/medcat_data'), upload_to='')), - ('model_display_name', models.CharField(max_length=50)), - ('model_description', models.TextField(default=django.utils.timezone.now, max_length=200)), - ], - ), - ] diff --git a/webapp/webapp/demo/migrations/__init__.py b/webapp/webapp/demo/migrations/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/webapp/webapp/demo/models.py 
b/webapp/webapp/demo/models.py deleted file mode 100644 index da2c8aa5d..000000000 --- a/webapp/webapp/demo/models.py +++ /dev/null @@ -1,31 +0,0 @@ -from django.db import models -from django.core.files.storage import FileSystemStorage - - -MODEL_FS = FileSystemStorage(location="/medcat_data") - - -# Create your models here. -class UploadedText(models.Model): - text = models.TextField(default="", blank=True) - create_time = models.DateTimeField(auto_now_add=True) - - -class Downloader(models.Model): - first_name = models.CharField(max_length=20) - last_name = models.CharField(max_length=20) - email = models.EmailField(max_length=50) - affiliation = models.CharField(max_length=100) - funder = models.CharField(max_length=100, blank=True, default="") - use_case = models.TextField(max_length=200) - downloaded_file = models.CharField(max_length=100) - - def __str__(self): - return f'{self.first_name} - {self.last_name}' - - -class MedcatModel(models.Model): - model_name = models.CharField(max_length=20, unique=True) - model_file = models.FileField(storage=MODEL_FS) - model_display_name = models.CharField(max_length=50) - model_description = models.TextField(max_length=200) diff --git a/webapp/webapp/demo/static/css/annotations.css b/webapp/webapp/demo/static/css/annotations.css deleted file mode 100644 index b0dc44531..000000000 --- a/webapp/webapp/demo/static/css/annotations.css +++ /dev/null @@ -1,110 +0,0 @@ -.textbox{ - margin-left: 5%; - margin-right: 5%; - margin-top: 5%; -} - -.train-annotations{ - margin-top: 3%; - min-height: 80vh; -} - -.annotations{ - padding-top: 1%; - overflow: scroll; - max-height: 82vh; - padding-left: 40px; - margin-left: -20px; - box-shadow: 5px -5px 5px -5px #9476518a; -} - -.green:hover { - color: green; -} - -.red:hover { - color: red; -} - - -div.icons{ - margin-top: -15px; - margin-left: 30px; - font-size: 30px; - position: relative; - float: right; -} - -table.info{ - margin-top: 0%; -} - -.concept-info{ - box-shadow: -5px -5px 5px -5px #aaaaaa; - padding-top: 1%; - overflow: scroll; - max-height: 82vh; - margin-left: 20px; -} - -.row{ - min-height: 80vh; - max-height: 80vh; -} - -mark{ - cursor: pointer; -} - -i{ - cursor: pointer; -} - - -.btns-a{ - margin-bottom:10px; -} - -.btns-a-r{ - margin-left: 10px; - margin-bottom:10px; - position: relative; - float: right; -} - -td.first{ - width: 50px; -} - -td.second{ - width: 100%; -} - -.w100{ - width: 100% -} - -.flt-right{ - float: right; - margin-right: 5px; -} - -.posfed{ - -webkit-animation: fadein 0.3s; - color: green; - float: right; - margin-right: 10px; - border-left: 4px solid black; - padding: 4px; - padding-top: 8px; - padding-bottom: 8px; - position: fixed; - top: 3px; - right: 3px; -} - -/* Safari, Chrome and Opera > 12.1 */ -@-webkit-keyframes fadein { - from { opacity: 0; } - to { opacity: 1; } -} diff --git a/webapp/webapp/demo/static/css/base.css b/webapp/webapp/demo/static/css/base.css deleted file mode 100644 index 0c76c6904..000000000 --- a/webapp/webapp/demo/static/css/base.css +++ /dev/null @@ -1,86 +0,0 @@ -/* Modal related styles */ -.modal-mask { - position: fixed; - z-index: 9998; - top: 0; - left: 0; - width: 100%; - height: 100%; - background-color: rgba(0, 0, 0, .5); - display: table; - transition: opacity .3s ease; -} - -.modal-wrapper { - display: table-cell; - vertical-align: middle; -} - -.modal-container { - width: 400px; - margin: 0px auto; - padding: 20px 30px; - background-color: #fff; - border-radius: 2px; - box-shadow: 0 2px 8px rgba(0, 0, 0, .33); - transition: all 
.3s ease; - font-family: Helvetica, Arial, sans-serif; -} - -.modal-header div { - width: 100%; -} - -.modal-header h3 { - margin-top: 0; - color: #42b983; - text-align: center; -} - -.modal-header h4 { - display: inline-block; -} - -.modal-header .close { - float: right; -} - -.modal-body { - margin: 20px 0; -} - -.modal-footer { - display: flex !important; - justify-content: center !important; -} - -.modal-default-button { - float: right; -} - -.modal-content { - border: 0; -} - -/* - * The following styles are auto-applied to elements with - * transition="modal" when their visibility is toggled - * by Vue.js. - * - * You can easily play with the modal transition by editing - * these styles. - */ - -.modal-enter { - opacity: 0; -} - -.modal-leave-active { - opacity: 0; -} - -.modal-enter .modal-container, -.modal-leave-active .modal-container { - -webkit-transform: scale(1.1); - transform: scale(1.1); -} diff --git a/webapp/webapp/demo/static/css/home.css b/webapp/webapp/demo/static/css/home.css deleted file mode 100644 index 12ed6666f..000000000 --- a/webapp/webapp/demo/static/css/home.css +++ /dev/null @@ -1,23 +0,0 @@ -p.welcome{ - margin-top: 30px; - margin-left: 40px; - font-size: 16px; -} - -.usecase{ - margin-top: 100px; -} - -.file-upload { - display: inline; -} - -.download { - display: inline; -} - -html, -body { - height: 100%; - margin: 0 -} diff --git a/webapp/webapp/demo/static/image/favicon.ico b/webapp/webapp/demo/static/image/favicon.ico deleted file mode 100644 index 5e8c53af57922a5280844f9001c03a1a1ef8a6de..0000000000000000000000000000000000000000 GIT binary patch (4641 bytes of binary favicon data omitted) diff --git a/webapp/webapp/demo/templates/base.html b/webapp/webapp/demo/templates/base.html deleted file mode 100644 --- a/webapp/webapp/demo/templates/base.html +++ /dev/null - - - CAT Trainer - - - - - - {% load static %} - - - {% block style %} - {% endblock %} - - - - {% block body %} - {% endblock %} - - - - - {% block script %} - {% endblock %} - - diff --git a/webapp/webapp/demo/templates/train_annotations.html b/webapp/webapp/demo/templates/train_annotations.html deleted file mode 100644 index 25677cd21..000000000 --- a/webapp/webapp/demo/templates/train_annotations.html +++ /dev/null @@ -1,147 +0,0 @@ -{% extends 'base.html' %} -{% load static %} - -{% block style %} - - - - -{% endblock %} - -{% block body %} - - {% if not doc_html %} -
-
- {% csrf_token %} -
- - - -
- -
-
-
-
-
Disclaimer
-

This software is intended solely for testing purposes and non-commercial use. THE SOFTWARE IS PROVIDED "AS IS", - WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. -

Contact contact@cogstack.com for more information.

-

-
-
Sample text
-
-Description: Intracerebral hemorrhage (very acute clinical changes occurred immediately).
-CC: Left hand numbness on presentation; then developed lethargy later that day.
-
-HX: On the day of presentation, this 72 y/o RHM suddenly developed generalized weakness and lightheadedness, and could not rise from a chair. Four hours later he experienced sudden left hand numbness lasting two hours. There were no other associated symptoms except for the generalized weakness and lightheadedness. He denied vertigo.
-
-He had been experiencing falling spells without associated LOC up to several times a month for the past year.
-
-MEDS: procardia SR, Lasix, Ecotrin, KCL, Digoxin, Colace, Coumadin.
-
-PMH: 1)8/92 evaluation for presyncope (Echocardiogram showed: AV fibrosis/calcification, AV stenosis/insufficiency, MV stenosis with annular calcification and regurgitation, moderate TR, Decreased LV systolic function, severe LAE. MRI brain: focal areas of increased T2 signal in the left cerebellum and in the brainstem probably representing microvascular ischemic disease. IVG (MUGA scan) revealed: global hypokinesis of the LV and biventricular dysfunction, RV ejection Fx 45% and LV ejection Fx 39%. He was subsequently placed on coumadin for severe valvular heart disease), 2)HTN, 3)Rheumatic fever and heart disease, 4)COPD, 5)ETOH abuse, 6)colonic polyps, 7)CAD, 8)CHF, 9)Appendectomy, 10)Junctional tachycardia.
-            
-
-
-
Please note this is a limited version of MedCAT; it has not been trained or validated by clinicians.
-
- - {% else %} -
-
- -
-
- {{ doc_html|safe }} -
-
-
-
[[selected_concept.pretty_name]]
- - - - - - -
- [[name]] - - [[value]] -
-
-
-
Create a new Concept
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
- Name - - -
- CUI - - -
- TUI - - -
- Source Value - - -
- Synonyms - - -
- Context - - -
-
-
-
-
- {% endif %} - - - - - - -{% endblock %} diff --git a/webapp/webapp/demo/templates/umls_user_validation.html b/webapp/webapp/demo/templates/umls_user_validation.html deleted file mode 100644 index 7ceb50f8e..000000000 --- a/webapp/webapp/demo/templates/umls_user_validation.html +++ /dev/null @@ -1,67 +0,0 @@ -{% extends 'base.html' %} -{% load static %} - -{% block style %} - - - - -{% endblock %} - -{% block body %} - -
-
{{ message }}
-
    - {% if downloader_form.non_field_errors %} -
  • {{ downloader_form.non_field_errors }}
  • - {% endif %} - {% for field in downloader_form %} - {% if field.errors %} -
  • - {{ field.label }} -
      - {% for error in field.errors %} -
    • {{ error }}
    • - {% endfor %} -
    -
  • - {% endif %} - {% endfor %} -
-{% if is_valid %} -
To update any information you submitted previously, contact contact@cogstack.org.
-
- {% csrf_token %} - {% for field in downloader_form.visible_fields %} - {% if field.name != 'consent' and field.name != 'modelpack' %} -
-
- {{ field }} -
- {% endif %} - {% endfor %} -
- {{ downloader_form.modelpack.label }}:
- {% for radio in downloader_form.modelpack %} - {{ radio }}
- {% endfor %} -
- {{ downloader_form.consent }} {{ downloader_form.consent.label }}

- -
-{% endif %} -
-{% endif %} - - - - - - - -{% endblock %} diff --git a/webapp/webapp/demo/tests.py deleted file mode 100644 index 7ce503c2d..000000000 --- a/webapp/webapp/demo/tests.py +++ /dev/null @@ -1,3 +0,0 @@ -from django.test import TestCase - -# Create your tests here. diff --git a/webapp/webapp/demo/urls.py deleted file mode 100644 index 8919757d0..000000000 --- a/webapp/webapp/demo/urls.py +++ /dev/null @@ -1,9 +0,0 @@ -from django.contrib import admin -from django.urls import path -from .views import * - -urlpatterns = [ - path('', show_annotations, name='train_annotations'), - path('auth-callback', validate_umls_user, name='validate-umls-user'), - path('download-model', download_model, name="download-model") -] diff --git a/webapp/webapp/demo/views.py deleted file mode 100644 index 58d6bf6a7..000000000 --- a/webapp/webapp/demo/views.py +++ /dev/null @@ -1,129 +0,0 @@ -import sys -sys.path.insert(0, '/home/ubuntu/projects/MedCAT/') -import os -import json -from django.shortcuts import render -from django.http import StreamingHttpResponse, HttpResponse -from wsgiref.util import FileWrapper -from medcat.cat import CAT -from medcat.cdb import CDB -from medcat.utils.helpers import doc2html -from medcat.vocab import Vocab -from urllib.request import urlretrieve, urlopen -from urllib.error import HTTPError -#from medcat.meta_cat import MetaCAT -from .models import * -from .forms import DownloaderForm - -AUTH_CALLBACK_SERVICE = 'https://medcat.rosalind.kcl.ac.uk/auth-callback' -VALIDATION_BASE_URL = 'https://uts-ws.nlm.nih.gov/rest/isValidServiceValidate' -VALIDATION_LOGIN_URL = f'https://uts.nlm.nih.gov/uts/login?service={AUTH_CALLBACK_SERVICE}' - -model_pack_path = os.getenv('MODEL_PACK_PATH', 'models/medmen_wstatus_2021_oct.zip') - -try: - cat = CAT.load_model_pack(model_pack_path) -except Exception as e: - print(str(e)) - -def get_html_and_json(text): - doc = cat(text) - - a = json.loads(cat.get_json(text)) - for ent_id, ent in a['annotations'].items(): - new_ent = {} - for key in ent.keys(): - if key == 'pretty_name': - new_ent['Pretty Name'] = ent[key] - if key == 'icd10': - icd10 = ent.get('icd10', []) - new_ent['ICD-10 Code'] = icd10[-1] if icd10 else '-' - if key == 'cui': - new_ent['Identifier'] = ent[key] - if key == 'types': - new_ent['Type'] = ", ".join(ent[key]) - if key == 'acc': - new_ent['Confidence Score'] = ent[key] - if key == 'start': - new_ent['Start Index'] = ent[key] - if key == 'end': - new_ent['End Index'] = ent[key] - if key == 'id': - new_ent['id'] = ent[key] - if key == 'meta_anns': - meta_anns = ent.get("meta_anns", {}) - if meta_anns: - for meta_ann in meta_anns.keys(): - new_ent[meta_ann] = meta_anns[meta_ann]['value'] - - a['annotations'][ent_id] = new_ent - - doc_json = json.dumps(a) - uploaded_text = UploadedText() - uploaded_text.text = len(str(text))  # store only the length; the submitted text itself is no longer saved - uploaded_text.save() - - return doc2html(doc), doc_json - - -def show_annotations(request): - context = {} - context['doc_json'] = '{"msg": "No documents yet"}' - - if request.POST and 'text' in request.POST: - doc_html, doc_json = get_html_and_json(request.POST['text']) - - context['doc_html'] = doc_html - context['doc_json'] = doc_json - context['text'] = request.POST['text'] - return render(request, 'train_annotations.html', context=context) - - -def validate_umls_user(request): - ticket = request.GET.get('ticket', '') - validate_url = f'{VALIDATION_BASE_URL}?service={AUTH_CALLBACK_SERVICE}&ticket={ticket}' -
try: - is_valid = urlopen(validate_url, timeout=10).read().decode('utf-8') - context = { - 'is_valid': is_valid == 'true' - } - if is_valid == 'true': - context['message'] = 'License verified! Please fill in the following form before downloading models.' - context['downloader_form'] = DownloaderForm(MedcatModel.objects.all()) - else: - context['message'] = f'License not found. Please request or renew your UMLS Metathesaurus License. If you believe you already hold a license, try {VALIDATION_LOGIN_URL} again.' - except HTTPError: - context = { - 'is_valid': False, - 'message': 'Something went wrong. Please try again.' - } - return render(request, 'umls_user_validation.html', context=context) - - -def download_model(request): - if request.method == 'POST': - downloader_form = DownloaderForm(MedcatModel.objects.all(), request.POST) - if downloader_form.is_valid(): - mp_name = downloader_form.cleaned_data['modelpack'] - model = MedcatModel.objects.get(model_name=mp_name) - if model is not None: - mp_path = model.model_file.path - else: - return HttpResponse(f'Error: Unknown model "{mp_name}"') - resp = StreamingHttpResponse(FileWrapper(open(mp_path, 'rb'))) - resp['Content-Type'] = 'application/zip' - resp['Content-Length'] = os.path.getsize(mp_path) - resp['Content-Disposition'] = f'attachment; filename={os.path.basename(mp_path)}' - downloader_form.instance.downloaded_file = os.path.basename(mp_path) - downloader_form.save() - return resp - else: - context = { - 'is_valid': True, - 'downloader_form': downloader_form, - 'message': 'All non-optional fields must be filled out:' - } - return render(request, 'umls_user_validation.html', context=context) - else: - return HttpResponse('Error: Unknown HTTP method.') diff --git a/webapp/webapp/etc/cron.d/db-backup-cron b/webapp/webapp/etc/cron.d/db-backup-cron deleted file mode 100644 index 6c8fe22b1..000000000 --- a/webapp/webapp/etc/cron.d/db-backup-cron +++ /dev/null @@ -1 +0,0 @@ -* * * * * /usr/local/bin/python /webapp/manage.py runcrons >/dev/null 2>&1 diff --git a/webapp/webapp/manage.py b/webapp/webapp/manage.py deleted file mode 100755 index cdef512a0..000000000 --- a/webapp/webapp/manage.py +++ /dev/null @@ -1,21 +0,0 @@ -#!/usr/bin/env python -"""Django's command-line utility for administrative tasks.""" -import os -import sys - - -def main(): - os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'webapp.settings') - try: - from django.core.management import execute_from_command_line - except ImportError as exc: - raise ImportError( - "Couldn't import Django. Are you sure it's installed and " - "available on your PYTHONPATH environment variable? Did you " - "forget to activate a virtual environment?"
- ) from exc - execute_from_command_line(sys.argv) - - -if __name__ == '__main__': - main() diff --git a/webapp/webapp/models/.keep b/webapp/webapp/models/.keep deleted file mode 100644 index e69de29bb..000000000 diff --git a/webapp/webapp/requirements.txt b/webapp/webapp/requirements.txt deleted file mode 100644 index c525cf0e4..000000000 --- a/webapp/webapp/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -Django==3.2.25 -django-dbbackup==4.0.0b0 -django-storages[boto3]==1.12.3 -django-cron==0.5.1 -medcat==1.2.7 -urllib3==1.26.18 diff --git a/webapp/webapp/webapp/__init__.py b/webapp/webapp/webapp/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/webapp/webapp/webapp/settings.py b/webapp/webapp/webapp/settings.py deleted file mode 100644 index cd68965dc..000000000 --- a/webapp/webapp/webapp/settings.py +++ /dev/null @@ -1,146 +0,0 @@ -""" -Django settings for webapp project. - -Generated by 'django-admin startproject' using Django 2.2.3. - -For more information on this file, see -https://docs.djangoproject.com/en/2.2/topics/settings/ - -For the full list of settings and their values, see -https://docs.djangoproject.com/en/2.2/ref/settings/ -""" - -import os - -# Build paths inside the project like this: os.path.join(BASE_DIR, ...) -BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - - -# Quick-start development settings - unsuitable for production -# See https://docs.djangoproject.com/en/2.2/howto/deployment/checklist/ - -# SECURITY WARNING: keep the secret key used in production secret! -SECRET_KEY = '2y3*wm_n52xyis_kaup96+5^*^*$h^!!na-$n%l9ppc0rhfea$' - -# SECURITY WARNING: don't run with debug turned on in production! -DEBUG = False - -ALLOWED_HOSTS = ['*'] - - -# Application definition - -INSTALLED_APPS = [ - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.staticfiles', - 'dbbackup', - 'django_cron', - 'demo', -] - -MIDDLEWARE = [ - 'django.middleware.security.SecurityMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.middleware.common.CommonMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - 'django.middleware.clickjacking.XFrameOptionsMiddleware', -] - -ROOT_URLCONF = 'webapp.urls' - -TEMPLATES = [ - { - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.debug', - 'django.template.context_processors.request', - 'django.contrib.auth.context_processors.auth', - 'django.contrib.messages.context_processors.messages', - ], - }, - }, -] - -WSGI_APPLICATION = 'webapp.wsgi.application' - - -# Database -# https://docs.djangoproject.com/en/2.2/ref/settings/#databases - -DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': os.path.join(BASE_DIR, 'db/db.sqlite3'), - } -} - -DB_BACKUP_ON_S3 = os.environ.get('DB_BACKUP_ON_S3', 'False') -DB_BACKUP_LOCATION = os.environ.get('DB_BACKUP_LOCATION', 'demo-db-backup/') -if DB_BACKUP_ON_S3 == "False": - DBBACKUP_STORAGE = 'django.core.files.storage.FileSystemStorage' - DBBACKUP_STORAGE_OPTIONS = {'location': f'/tmp/{DB_BACKUP_LOCATION}'} -else: - DBBACKUP_STORAGE = 'storages.backends.s3boto3.S3Boto3Storage' - DBBACKUP_STORAGE_OPTIONS = { - 'region_name': 'eu-west-2', - 'access_key': 
os.environ.get('ACCESS_KEY', ''), - 'secret_key': os.environ.get('SECRET_KEY', ''), - 'bucket_name': os.environ.get('BUCKET_NAME', ''), - 'default_acl': 'bucket-owner-full-control', - 'location': DB_BACKUP_LOCATION, - } - -CRON_CLASSES = [ - 'demo.db_backup.DbBackup', -] -DJANGO_CRON_DELETE_LOGS_OLDER_THAN = int(os.environ.get('DELETE_LOGS_OLDER_THAN', '7')) - -# Password validation -# https://docs.djangoproject.com/en/2.2/ref/settings/#auth-password-validators - -AUTH_PASSWORD_VALIDATORS = [ - { - 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', - }, -] - - -# Internationalization -# https://docs.djangoproject.com/en/2.2/topics/i18n/ - -LANGUAGE_CODE = 'en-us' - -TIME_ZONE = 'UTC' - -USE_I18N = True - -USE_L10N = True - -USE_TZ = True - - -# Static files (CSS, JavaScript, Images) -# https://docs.djangoproject.com/en/2.2/howto/static-files/ - -STATIC_URL = '/static/' -STATIC_ROOT = os.path.join(BASE_DIR, 'demo', 'static') -MEDIA_URL = '/media/' -MEDIA_ROOT = os.path.join(BASE_DIR, 'data') diff --git a/webapp/webapp/webapp/urls.py b/webapp/webapp/webapp/urls.py deleted file mode 100644 index 703574edf..000000000 --- a/webapp/webapp/webapp/urls.py +++ /dev/null @@ -1,26 +0,0 @@ -"""webapp URL Configuration - -The `urlpatterns` list routes URLs to views. For more information please see: - https://docs.djangoproject.com/en/2.2/topics/http/urls/ -Examples: -Function views - 1. Add an import: from my_app import views - 2. Add a URL to urlpatterns: path('', views.home, name='home') -Class-based views - 1. Add an import: from other_app.views import Home - 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') -Including another URLconf - 1. Import the include() function: from django.urls import include, path - 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) -""" -from django.contrib import admin -from django.urls import path, include, re_path -from django.conf import settings -from django.views.static import serve - -urlpatterns = [ - path('admin/', admin.site.urls), - path('', include('demo.urls')), - re_path(r'^static/(?P<path>.*)$', serve, {'document_root': settings.STATIC_ROOT}), - re_path(r'^media/(?P<path>.*)$', serve, {'document_root': settings.MEDIA_ROOT}), -] diff --git a/webapp/webapp/webapp/wsgi.py b/webapp/webapp/webapp/wsgi.py deleted file mode 100644 index 420a2338a..000000000 --- a/webapp/webapp/webapp/wsgi.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -WSGI config for webapp project. - -It exposes the WSGI callable as a module-level variable named ``application``. - -For more information on this file, see -https://docs.djangoproject.com/en/2.2/howto/deployment/wsgi/ -""" - -import os - -from django.core.wsgi import get_wsgi_application - -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'webapp.settings') - -application = get_wsgi_application()
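Reviewer note on the UMLS licence gate deleted above: the essential round trip in validate_umls_user can be sanity-checked outside Django. The sketch below is illustrative rather than part of the removed code: the helper name is_valid_umls_ticket and the sample ticket value are made up, and it assumes the UTS endpoint still answers with the literal plain-text body 'true' for a valid service/ticket pair, which is what the deleted view checked for.

from urllib.error import HTTPError
from urllib.request import urlopen

# Constants copied from the deleted webapp/webapp/demo/views.py.
AUTH_CALLBACK_SERVICE = 'https://medcat.rosalind.kcl.ac.uk/auth-callback'
VALIDATION_BASE_URL = 'https://uts-ws.nlm.nih.gov/rest/isValidServiceValidate'


def is_valid_umls_ticket(ticket: str, timeout: int = 10) -> bool:
    # The UTS service replies 'true' or 'false' for the service/ticket pair;
    # an HTTP error means the licence could not be verified, which the
    # deleted view likewise treated as a failure.
    url = f'{VALIDATION_BASE_URL}?service={AUTH_CALLBACK_SERVICE}&ticket={ticket}'
    try:
        return urlopen(url, timeout=timeout).read().decode('utf-8') == 'true'
    except HTTPError:
        return False


if __name__ == '__main__':
    # 'ST-12345' is a placeholder: a real single-use ticket arrives on the
    # auth-callback query string after the UTS login redirect.
    print(is_valid_umls_ticket('ST-12345'))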