diff --git a/tests/archive_tests/test_ner_archive.py b/tests/archive_tests/test_ner_archive.py index 439413073..d41ccd0c7 100644 --- a/tests/archive_tests/test_ner_archive.py +++ b/tests/archive_tests/test_ner_archive.py @@ -1,5 +1,4 @@ import logging -import os import unittest import numpy as np from timeit import default_timer as timer @@ -15,6 +14,8 @@ from medcat.linking.context_based_linker import Linker from medcat.config import Config +from ..helper import VocabDownloader + class NerArchiveTests(unittest.TestCase): @@ -35,12 +36,9 @@ def setUp(self) -> None: # Check #assert cdb.cui2names == {'S-229004': {'movar', 'movarvirus', 'movarviruses'}, 'S-229005': {'cdb'}} - self.vocab_path = "./tmp_vocab.dat" - if not os.path.exists(self.vocab_path): - import requests - tmp = requests.get("https://medcat.rosalind.kcl.ac.uk/media/vocab.dat") - with open(self.vocab_path, 'wb') as f: - f.write(tmp.content) + downloader = VocabDownloader() + self.vocab_path = downloader.vocab_path + downloader.check_or_download() vocab = Vocab.load(self.vocab_path) # Make the pipeline diff --git a/tests/helper.py b/tests/helper.py index 1cb284ad7..483a6d1ad 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -1,6 +1,76 @@ +import os +import requests import unittest +import numpy as np + +from medcat.vocab import Vocab + class AsyncMock(unittest.mock.MagicMock): async def __call__(self, *args, **kwargs): return super().__call__(*args, **kwargs) + + +ERROR_503 = b""" + +503 Service Unavailable + +

Service Unavailable

+

The server is temporarily unable to service your +request due to maintenance downtime or capacity +problems. Please try again later.

+ +""" + +SIMPLE_WORDS = """house 34444 0.3232 0.123213 1.231231 +dog 14444 0.76762 0.76767 1.45454""" + + +def generate_simple_vocab(): + v = Vocab() + # v.add_words() + for line in SIMPLE_WORDS.split('\n'): + parts = line.split("\t") + word = parts[0] + cnt = int(parts[1].strip()) + vec = None + if len(parts) == 3: + vec = np.array([float(x) for x in parts[2].strip().split(" ")]) + + v.add_word(word, cnt, vec, replace=True) + v.make_unigram_table() + return v + + +class VocabDownloader: + url = 'https://medcat.rosalind.kcl.ac.uk/media/vocab.dat' + vocab_path = "./tmp_vocab.dat" + _has_simple = False + + def is_valid(self): + with open(self.vocab_path, 'rb') as f: + content = f.read() + if content == ERROR_503: + return False + v = Vocab.load(self.vocab_path) + if len(v.vocab) == 2: # simple one + self._has_simple = True + return False + return True + + def check_or_download(self): + if os.path.exists(self.vocab_path) and self.is_valid(): + return + tmp = requests.get(self.url) + if tmp.content == ERROR_503: + print('Rosalind server unavailable') + if self._has_simple: + print('Local simple vocab already present') + return + print('Generating local simple vocab instead') + v = generate_simple_vocab() + v.save(self.vocab_path) + return + with open(self.vocab_path, 'wb') as f: + f.write(tmp.content) diff --git a/tests/medmentions/make_cdb.py b/tests/medmentions/make_cdb.py index 52929b31f..feb8629d2 100644 --- a/tests/medmentions/make_cdb.py +++ b/tests/medmentions/make_cdb.py @@ -3,7 +3,9 @@ from functools import partial import numpy as np import logging -import os + +from ..helper import VocabDownloader + config = Config() config.general['log_level'] = logging.INFO @@ -21,12 +23,9 @@ from medcat.cdb import CDB from medcat.cat import CAT -vocab_path = "./tmp_vocab.dat" -if not os.path.exists(vocab_path): - import requests - tmp = requests.get("https://s3-eu-west-1.amazonaws.com/zkcl/vocab.dat") - with open(vocab_path, 'wb') as f: - f.write(tmp.content) +downloader = VocabDownloader() +vocab_path = downloader.vocab_path +downloader.check_or_download() config = Config() cdb = CDB.load("./tmp_cdb.dat", config=config) diff --git a/tests/test_ner.py b/tests/test_ner.py index 1ae6e375d..b5b185842 100644 --- a/tests/test_ner.py +++ b/tests/test_ner.py @@ -1,6 +1,4 @@ import logging -import os -import requests import unittest from spacy.lang.en import English from medcat.preprocessing.tokenizers import spacy_split_all @@ -14,6 +12,8 @@ from medcat.config import Config from medcat.cdb import CDB +from .helper import VocabDownloader + class A_NERTests(unittest.TestCase): @classmethod @@ -25,11 +25,9 @@ def setUpClass(cls): cls.cdb = CDB(config=cls.config) print("Set up Vocab") - vocab_path = "./tmp_vocab.dat" - if not os.path.exists(vocab_path): - tmp = requests.get("https://medcat.rosalind.kcl.ac.uk/media/vocab.dat") - with open(vocab_path, 'wb') as f: - f.write(tmp.content) + downloader = VocabDownloader() + vocab_path = downloader.vocab_path + downloader.check_or_download() cls.vocab = Vocab.load(vocab_path) diff --git a/tests/test_pipe.py b/tests/test_pipe.py index 7f5bd2ece..e6da42898 100644 --- a/tests/test_pipe.py +++ b/tests/test_pipe.py @@ -1,7 +1,5 @@ import unittest import logging -import os -import requests from spacy.language import Language from medcat.cdb import CDB from medcat.vocab import Vocab @@ -17,6 +15,8 @@ from transformers import AutoTokenizer +from .helper import VocabDownloader + class PipeTests(unittest.TestCase): @@ -30,11 +30,9 @@ def setUpClass(cls) -> None: cls.config.linking['disamb_length_limit'] = 2 cls.cdb = CDB(config=cls.config) - vocab_path = "./tmp_vocab.dat" - if not os.path.exists(vocab_path): - tmp = requests.get("https://medcat.rosalind.kcl.ac.uk/media/vocab.dat") - with open(vocab_path, 'wb') as f: - f.write(tmp.content) + downloader = VocabDownloader() + vocab_path = downloader.vocab_path + downloader.check_or_download() cls.vocab = Vocab.load(vocab_path) cls.spell_checker = BasicSpellChecker(cdb_vocab=cls.cdb.vocab, config=cls.config, data_vocab=cls.vocab)