Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix export_to_brat when there are spaces before new lines #211

Merged
merged 4 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ jobs:
- name: Install dependencies
run: |
pip install --upgrade pip
pip install -e '.[dev,docs,setup]'
pip install -e '.[dev,setup]'
- name: Test with Pytest on Python ${{ matrix.python-version }}
env:
Expand Down
5 changes: 5 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## Pending

### Fixed
- `export_to_brat` issue with spans of entities on multiple lines.

## v0.8.1 (2023-05-31)

Fix release to allow installation from source
Expand Down
13 changes: 6 additions & 7 deletions edsnlp/connectors/brat.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,18 +226,17 @@ def export_to_brat(doc, txt_filename, overwrite_txt=False, overwrite_ann=False):
):
idx = fragment["begin"]
entity_text = doc["text"][fragment["begin"] : fragment["end"]]
for part in entity_text.split("\n"):
begin = idx
end = idx + len(part)
idx = end + 1
if begin != end:
spans.append((begin, end))
# eg: "mon entité \n est problématique"
for match in re.finditer(
r"\s*(.+?)(?:( *\n+)+ *|$)", entity_text, flags=re.DOTALL
):
spans.append((idx + match.start(1), idx + match.end(1)))
print(
"{}\t{} {}\t{}".format(
brat_entity_id,
str(entity["label"]),
";".join(" ".join(map(str, span)) for span in spans),
entity_text.replace("\n", " "),
" ".join(doc["text"][begin:end] for begin, end in spans),
),
file=f,
)
Expand Down
6 changes: 3 additions & 3 deletions edsnlp/pipelines/ner/umls/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ def get_patterns(config: Dict[str, Any]) -> Dict[str, List[str]]:

path, module, filename = get_path(config)

if path.exists():
if path.exists(): # pragma: no cover
print(f"Loading {filename} from {module.base}")
return module.load_pickle(name=filename)
else:
else: # pragma: no cover
patterns = download_and_agg_umls(config)
module.dump_pickle(name=filename, obj=patterns)
print(f"Saved patterns into {module.base / filename}")
Expand Down Expand Up @@ -108,7 +108,7 @@ def download_and_agg_umls(config) -> Dict[str, List[str]]:
"""

api_key = os.getenv("UMLS_API_KEY")
if not api_key:
if not api_key: # pragma: no cover
warnings.warn(
"You need to define UMLS_API_KEY to download the UMLS. "
"Get a key by creating an account at "
Expand Down
25 changes: 21 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,21 @@ dependencies = [
"decorator",
"loguru",
"pendulum>=2.1.2",
"pydantic>=1.8.2,<2.0.0",
"pydantic>=1.10.2,<2.0.0",
"pysimstring>=1.2.1",
"regex",
"rich>=12.0.0",
"scikit-learn>=1.0.0",
"spacy>=3.1,<3.5.0",
"thinc>=8.0.14,<8.1.11",
"thinc>=8.1.10",
"tqdm",
"umls-downloader>=0.1.1",
"numpy>=1.15.0,<1.23.2; python_version<'3.8'",
"numpy>=1.15.0; python_version>='3.8'",
"pandas>=1.1.0,<2.0.0; python_version<'3.8'",
"pandas>=1.4.0,<2.0.0; python_version>='3.8'",
"typing_extensions<4.6.0" # https://github.com/explosion/spaCy/issues/12659
"typing_extensions<4.6.0,>=4.0.0; python_version>='3.8'", # https://github.com/explosion/spaCy/issues/12659
"typing_extensions>=4.0.0; python_version<'3.8'"
]
[project.optional-dependencies]
dev = [
Expand All @@ -39,7 +40,7 @@ dev = [
"pytest>=7.1.0,<8.0.0",
"pytest-cov>=3.0.0,<4.0.0",
"pytest-html>=3.1.1,<4.0.0",
"torch>=1.0.0,<1.13.0",
"torch>=1.0.0",
]
setup = [
"mlconjug3<3.9.0",
Expand Down Expand Up @@ -197,6 +198,22 @@ omit-covered-files = false
# generate-badge = "."
# badge-format = "svg"


[tool.coverage]
exclude_lines = [
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
"if typing.TYPE_CHECKING:",
"@overload",
"pragma: no cover",
"raise AssertionError",
"raise NotImplementedError",
"def __repr__",
"Span.set_extension.*",
"Doc.set_extension.*",
"Token.set_extension.*",
]

[tool.cibuildwheel]
skip = [
"*p36-*", # Skip Python 3.6
Expand Down
2 changes: 1 addition & 1 deletion tests/connectors/test_brat.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def test_brat(
A1 etat T1 test
T2 localisation 39 57 dans le bras droit
T3 anatomie 47 57 bras droit
T4 pathologie 75 84;85 98 problème de locomotion
T4 pathologie 75 83;85 98 problème de locomotion
A2 assertion T4 absent
T5 pathologie 114 117 AVC
A3 etat T5 passé
Expand Down
2 changes: 2 additions & 0 deletions tests/pipelines/ner/test_umls.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
pattern_config = {"lang": ["FRE"], "sources": ["MSHFRE"]}


@pytest.mark.skipif(not os.getenv("UMLS_API_KEY"), reason="No UMLS_API_KEY given")
def test_get_patterns():

path, _, _ = get_path(pattern_config)
Expand All @@ -41,6 +42,7 @@ def test_get_patterns():
assert len(patterns) == 48587


@pytest.mark.skipif(not os.getenv("UMLS_API_KEY"), reason="No UMLS_API_KEY given")
def test_add_pipe(blank_nlp: Language):
path, _, _ = get_path(pattern_config)
if not path.exists():
Expand Down
Loading