From 33e14cab529a3f953113dd289f9981c29ab5530b Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Tue, 27 Jun 2023 22:26:30 +0200 Subject: [PATCH 01/12] added `pipe()` to spaCy integration --- span_marker/spacy_integration.py | 56 ++++++++++++++++++++++++++------ 1 file changed, 46 insertions(+), 10 deletions(-) diff --git a/span_marker/spacy_integration.py b/span_marker/spacy_integration.py index ad01cd28..50800b71 100644 --- a/span_marker/spacy_integration.py +++ b/span_marker/spacy_integration.py @@ -1,9 +1,11 @@ import os -from typing import Any, Optional, Union +from typing import Optional, Union import torch from datasets import Dataset -from spacy.tokens import Doc, Span +from spacy.tokens import Doc +from spacy.util import minibatch +import types from span_marker.modeling import SpanMarkerModel @@ -71,21 +73,28 @@ def __init__( self.model.to("cuda") self.batch_size = batch_size + @staticmethod + def convert_inputs_to_dataset(inputs): + inputs = Dataset.from_dict( + { + "tokens": inputs, + "document_id": [0] * len(inputs), + "sentence_id": range(len(inputs)), + } + ) + return inputs + + def __call__(self, doc: Doc) -> Doc: """Fill `doc.ents` and `span.label_` using the chosen SpanMarker model.""" sents = list(doc.sents) inputs = [[token.text if not token.is_space else "" for token in sent] for sent in sents] + # use document-level context in the inference if the model was also trained that way if self.model.config.trained_with_document_context: - inputs = Dataset.from_dict( - { - "tokens": inputs, - "document_id": [0] * len(inputs), - "sentence_id": range(len(inputs)), - } - ) - outputs = [] + inputs = self.convert_inputs_to_dataset(inputs) + outputs = [] entities_list = self.model.predict(inputs, batch_size=self.batch_size) for sentence, entities in zip(sents, entities_list): for entity in entities: @@ -97,3 +106,30 @@ def __call__(self, doc: Doc) -> Doc: doc.set_ents(outputs) return doc + + def pipe(self, stream, batch_size=128, include_sent=None): + """Fill `doc.ents` and `span.label_` using the chosen SpanMarker model.""" + if isinstance(stream, str): + stream = [stream] + + if not isinstance(stream, types.GeneratorType): + stream = self.nlp.pipe(stream, batch_size=batch_size) + + for docs in minibatch(stream, size=batch_size): + inputs = [[token.text if not token.is_space else "" for token in doc] for doc in docs] + + # use document-level context in the inference if the model was also trained that way + if self.model.config.trained_with_document_context: + inputs = self.convert_inputs_to_dataset(inputs) + + entities_list = self.model.predict(inputs, batch_size=self.batch_size) + for doc, entities in zip(docs, entities_list): + outputs = [] + for entity in entities: + start = entity["word_start_index"] + end = entity["word_end_index"] + span = doc[start:end] + span.label_ = entity["label"] + outputs.append(span) + doc.set_ents(outputs) + yield doc From 37fb4318d9ece7c273f41ef7c59fbc541022cb26 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Tue, 27 Jun 2023 22:28:10 +0200 Subject: [PATCH 02/12] added spaCy `.pipe()` integration tests --- tests/test_spacy_integration.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/test_spacy_integration.py b/tests/test_spacy_integration.py index f835bb45..0c4ae1c7 100644 --- a/tests/test_spacy_integration.py +++ b/tests/test_spacy_integration.py @@ -27,3 +27,32 @@ def test_span_marker_as_spacy_pipeline_component(): ("Atlantic", "LOC"), ("Paris", "LOC"), ] + +def test_span_marker_as_spacy_pipeline_component_pipe(): + nlp = spacy.load("en_core_web_sm", disable=["ner"]) + batch_size = 2 + wrapper = nlp.add_pipe( + "span_marker", config={"model": "tomaarsen/span-marker-bert-tiny-conll03", "batch_size": batch_size} + ) + assert wrapper.batch_size == batch_size + + docs = nlp.pipe(["Amelia Earhart flew her single engine Lockheed Vega 5B across the Atlantic to Paris."]) + doc = list(docs)[0] + assert [(span.text, span.label_) for span in doc.ents] == [ + ("Amelia Earhart", "PER"), + ("Lockheed Vega", "ORG"), + ("Atlantic", "LOC"), + ("Paris", "LOC"), + ] + + # Override a setting that modifies how inference is performed, + # should not have any impact with just one sentence, i.e. it should still work. + wrapper.model.config.trained_with_document_context = True + docs = nlp.pipe(["Amelia Earhart flew her single engine Lockheed Vega 5B across the Atlantic to Paris."]) + doc = list(docs)[0] + assert [(span.text, span.label_) for span in doc.ents] == [ + ("Amelia Earhart", "PER"), + ("Lockheed Vega", "ORG"), + ("Atlantic", "LOC"), + ("Paris", "LOC"), + ] From f34f813e14bd9c84af6da4024273e671d42b21a5 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Wed, 12 Jul 2023 23:00:30 +0200 Subject: [PATCH 03/12] chore: avoid overwriting pre-existing entities #17 --- span_marker/spacy_integration.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/span_marker/spacy_integration.py b/span_marker/spacy_integration.py index 7a804fcb..ec110562 100644 --- a/span_marker/spacy_integration.py +++ b/span_marker/spacy_integration.py @@ -4,7 +4,7 @@ import torch from datasets import Dataset from spacy.tokens import Doc -from spacy.util import minibatch +from spacy.util import minibatch, filter_spans import types from span_marker.modeling import SpanMarkerModel @@ -104,7 +104,7 @@ def __call__(self, doc: Doc) -> Doc: span.label_ = entity["label"] outputs.append(span) - doc.set_ents(outputs) + doc.set_ents(filter_spans(list(doc.ents) + outputs)) return doc def pipe(self, stream, batch_size=128, include_sent=None): @@ -131,5 +131,5 @@ def pipe(self, stream, batch_size=128, include_sent=None): span = doc[start:end] span.label_ = entity["label"] outputs.append(span) - doc.set_ents(outputs) + doc.set_ents(filter_spans(list(doc.ents) + outputs)) yield doc From 07118043b47cc62202e739d51627f697c28eb558 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Thu, 13 Jul 2023 08:34:43 +0200 Subject: [PATCH 04/12] chore: disable removing NER pipeline by default --- span_marker/__init__.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/span_marker/__init__.py b/span_marker/__init__.py index e07278b0..7286df17 100644 --- a/span_marker/__init__.py +++ b/span_marker/__init__.py @@ -40,13 +40,6 @@ def _spacy_span_marker_factory( batch_size: int, device: Optional[Union[str, torch.device]], ) -> SpacySpanMarkerWrapper: - # Remove the existing NER component, if it exists, - # to allow for SpanMarker to act as a drop-in replacement - try: - nlp.remove_pipe("ner") - except ValueError: - # The `ner` pipeline component was not found - pass return SpacySpanMarkerWrapper(model, batch_size=batch_size, device=device) From 4f58b6b6d71ec2c1da9614365c7db5652e9cd520 Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Thu, 13 Jul 2023 08:42:32 +0200 Subject: [PATCH 05/12] chore: added batch size warning --- span_marker/spacy_integration.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/span_marker/spacy_integration.py b/span_marker/spacy_integration.py index ec110562..bde867b3 100644 --- a/span_marker/spacy_integration.py +++ b/span_marker/spacy_integration.py @@ -1,11 +1,12 @@ import os +import types +import warnings from typing import Optional, Union import torch from datasets import Dataset from spacy.tokens import Doc -from spacy.util import minibatch, filter_spans -import types +from spacy.util import filter_spans, minibatch from span_marker.modeling import SpanMarkerModel @@ -107,8 +108,16 @@ def __call__(self, doc: Doc) -> Doc: doc.set_ents(filter_spans(list(doc.ents) + outputs)) return doc - def pipe(self, stream, batch_size=128, include_sent=None): + def pipe(self, stream, batch_size=128): """Fill `doc.ents` and `span.label_` using the chosen SpanMarker model.""" + if batch_size != self.batch_size: + warnings.warn( + ( + f"Using a different spaCy batch size ({batch_size}) than the one used for initialization of SpanMarker ({self.batch_size}).", + "This might lead to sub-optimal inference. Consider using the same batch size for both." + ) + ) + if isinstance(stream, str): stream = [stream] From 12b52f3bc4bec4b78ba1fa91da53bba4ba623cbb Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Fri, 14 Jul 2023 10:37:46 +0200 Subject: [PATCH 06/12] chore: added overwrite_entities flag chore: removed warning chore: updated changelog --- CHANGELOG.md | 11 +++++++++++ span_marker/__init__.py | 10 ++++++++++ span_marker/spacy_integration.py | 25 ++++++++++++++----------- 3 files changed, 35 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ec0159f5..6c82e6c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,17 @@ Types of changes * "Security" in case of vulnerabilities. --> +## [1.2.4] + +### Fixed + +- Fix overwriting spaCy entities by default + +### Added + +- Add `overwrite_entities` parameter to allow for choosing overwriting spaCy entities. +- Add `.pipe()` method to spaCy integration to allow for batched inference. + ## [1.2.3] ### Fixed diff --git a/span_marker/__init__.py b/span_marker/__init__.py index 7286df17..9d03878d 100644 --- a/span_marker/__init__.py +++ b/span_marker/__init__.py @@ -26,6 +26,7 @@ "model": "tomaarsen/span-marker-roberta-large-ontonotes5", "batch_size": 4, "device": None, + "overwrite_entities": False } @Language.factory( @@ -39,7 +40,16 @@ def _spacy_span_marker_factory( model: str, batch_size: int, device: Optional[Union[str, torch.device]], + overwrite_entities: Optional[bool] ) -> SpacySpanMarkerWrapper: + if overwrite_entities: + # Remove the existing NER component, if it exists, + # to allow for SpanMarker to act as a drop-in replacement + try: + nlp.remove_pipe("ner") + except ValueError: + # The `ner` pipeline component was not found + pass return SpacySpanMarkerWrapper(model, batch_size=batch_size, device=device) diff --git a/span_marker/spacy_integration.py b/span_marker/spacy_integration.py index bde867b3..9152d38f 100644 --- a/span_marker/spacy_integration.py +++ b/span_marker/spacy_integration.py @@ -1,6 +1,5 @@ import os import types -import warnings from typing import Optional, Union import torch @@ -56,6 +55,7 @@ def __init__( *args, batch_size: int = 4, device: Optional[Union[str, torch.device]] = None, + overwrite_entities: Optional[bool] = False, **kwargs, ) -> None: """Initialize a SpanMarker wrapper for spaCy. @@ -66,6 +66,7 @@ def __init__( batch_size (int): The number of samples to include per batch. Higher is faster, but requires more memory. Defaults to 4. device (Optional[Union[str, torch.device]]): The device to place the model on. Defaults to None. + overwrite_entities (Optional[bool]): Whether to overwrite the existing entities in the `doc.ents` attribute. """ self.model = SpanMarkerModel.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) if device: @@ -73,6 +74,7 @@ def __init__( elif torch.cuda.is_available(): self.model.to("cuda") self.batch_size = batch_size + self.overwrite_entities = overwrite_entities @staticmethod def convert_inputs_to_dataset(inputs): @@ -105,19 +107,15 @@ def __call__(self, doc: Doc) -> Doc: span.label_ = entity["label"] outputs.append(span) - doc.set_ents(filter_spans(list(doc.ents) + outputs)) + if self.overwrite_entities: + doc.set_ents(outputs) + else: + doc.set_ents(filter_spans(list(doc.ents) + outputs)) + return doc def pipe(self, stream, batch_size=128): """Fill `doc.ents` and `span.label_` using the chosen SpanMarker model.""" - if batch_size != self.batch_size: - warnings.warn( - ( - f"Using a different spaCy batch size ({batch_size}) than the one used for initialization of SpanMarker ({self.batch_size}).", - "This might lead to sub-optimal inference. Consider using the same batch size for both." - ) - ) - if isinstance(stream, str): stream = [stream] @@ -140,5 +138,10 @@ def pipe(self, stream, batch_size=128): span = doc[start:end] span.label_ = entity["label"] outputs.append(span) - doc.set_ents(filter_spans(list(doc.ents) + outputs)) + + if self.overwrite_entities: + doc.set_ents(outputs) + else: + doc.set_ents(filter_spans(list(doc.ents) + outputs)) + yield doc From 57390d4db920d67921f5bcdfec54f9701238d89a Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Fri, 14 Jul 2023 10:41:08 +0200 Subject: [PATCH 07/12] fix: resolved small typo --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c82e6c2..9cc18368 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,8 +23,8 @@ Types of changes ### Added -- Add `overwrite_entities` parameter to allow for choosing overwriting spaCy entities. -- Add `.pipe()` method to spaCy integration to allow for batched inference. +- Added `overwrite_entities` parameter to allow for choosing overwriting spaCy entities. +- Added `.pipe()` method to spaCy integration to allow for batched inference. ## [1.2.3] From f02469e3eef1356fb7f8f14ffe0ea4e1c42b3379 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 24 Aug 2023 11:06:24 +0200 Subject: [PATCH 08/12] Small refactor + formatting * Removed Optional from `overwrite_entities` * Introduce `set_ents` method to prevent duplicate code --- span_marker/__init__.py | 4 ++-- span_marker/spacy_integration.py | 32 ++++++++++++++++---------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/span_marker/__init__.py b/span_marker/__init__.py index 4902357a..0aa38554 100644 --- a/span_marker/__init__.py +++ b/span_marker/__init__.py @@ -26,7 +26,7 @@ "model": "tomaarsen/span-marker-roberta-large-ontonotes5", "batch_size": 4, "device": None, - "overwrite_entities": False + "overwrite_entities": False, } @Language.factory( @@ -40,7 +40,7 @@ def _spacy_span_marker_factory( model: str, batch_size: int, device: Optional[Union[str, torch.device]], - overwrite_entities: Optional[bool] + overwrite_entities: bool, ) -> SpacySpanMarkerWrapper: if overwrite_entities: # Remove the existing NER component, if it exists, diff --git a/span_marker/spacy_integration.py b/span_marker/spacy_integration.py index 9152d38f..9778e3aa 100644 --- a/span_marker/spacy_integration.py +++ b/span_marker/spacy_integration.py @@ -1,10 +1,10 @@ import os import types -from typing import Optional, Union +from typing import List, Optional, Union import torch from datasets import Dataset -from spacy.tokens import Doc +from spacy.tokens import Doc, Span from spacy.util import filter_spans, minibatch from span_marker.modeling import SpanMarkerModel @@ -55,7 +55,7 @@ def __init__( *args, batch_size: int = 4, device: Optional[Union[str, torch.device]] = None, - overwrite_entities: Optional[bool] = False, + overwrite_entities: bool = False, **kwargs, ) -> None: """Initialize a SpanMarker wrapper for spaCy. @@ -66,7 +66,8 @@ def __init__( batch_size (int): The number of samples to include per batch. Higher is faster, but requires more memory. Defaults to 4. device (Optional[Union[str, torch.device]]): The device to place the model on. Defaults to None. - overwrite_entities (Optional[bool]): Whether to overwrite the existing entities in the `doc.ents` attribute. + overwrite_entities (bool): Whether to overwrite the existing entities in the `doc.ents` attribute. + Defaults to False. """ self.model = SpanMarkerModel.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) if device: @@ -87,6 +88,11 @@ def convert_inputs_to_dataset(inputs): ) return inputs + def set_ents(self, doc: Doc, ents: List[Span]): + if self.overwrite_entities: + doc.set_ents(ents) + else: + doc.set_ents(filter_spans(list(doc.ents) + ents)) def __call__(self, doc: Doc) -> Doc: """Fill `doc.ents` and `span.label_` using the chosen SpanMarker model.""" @@ -97,7 +103,7 @@ def __call__(self, doc: Doc) -> Doc: if self.model.config.trained_with_document_context: inputs = self.convert_inputs_to_dataset(inputs) - outputs = [] + ents = [] entities_list = self.model.predict(inputs, batch_size=self.batch_size) for sentence, entities in zip(sents, entities_list): for entity in entities: @@ -105,12 +111,9 @@ def __call__(self, doc: Doc) -> Doc: end = entity["word_end_index"] span = sentence[start:end] span.label_ = entity["label"] - outputs.append(span) + ents.append(span) - if self.overwrite_entities: - doc.set_ents(outputs) - else: - doc.set_ents(filter_spans(list(doc.ents) + outputs)) + self.set_ents(doc, ents) return doc @@ -131,17 +134,14 @@ def pipe(self, stream, batch_size=128): entities_list = self.model.predict(inputs, batch_size=self.batch_size) for doc, entities in zip(docs, entities_list): - outputs = [] + ents = [] for entity in entities: start = entity["word_start_index"] end = entity["word_end_index"] span = doc[start:end] span.label_ = entity["label"] - outputs.append(span) + ents.append(span) - if self.overwrite_entities: - doc.set_ents(outputs) - else: - doc.set_ents(filter_spans(list(doc.ents) + outputs)) + self.set_ents(doc, ents) yield doc From 699659e6e9c47891c3870d91be5513020a2493fb Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 24 Aug 2023 11:07:59 +0200 Subject: [PATCH 09/12] Update changelog --- CHANGELOG.md | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b77955d0..8bbb1df4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,22 +15,25 @@ Types of changes * "Security" in case of vulnerabilities. --> -## [1.2.5] +## [Unreleased] -### Fixed +### Added -- Allow for immutable `TrainingArguments` from newer `transformers` release. +- Added an `overwrite_entities` parameter to the spaCy pipeline component to allow for overwriting spaCy entities. +- Added `.pipe()` method to spaCy integration to allow for batched inference. -## [1.2.4] +### Changed + +- Stop overwriting spaCy entities by default. + +## [1.2.5] ### Fixed -- Fix overwriting spaCy entities by default +- Allow for immutable `TrainingArguments` from newer `transformers` release. -### Added +## [1.2.4] -- Added `overwrite_entities` parameter to allow for choosing overwriting spaCy entities. -- Added `.pipe()` method to spaCy integration to allow for batched inference. - Resolved broken license information. ## [1.2.3] From 1653b0efc1324f1adf646041f66319b0db31d5c6 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 24 Aug 2023 11:13:28 +0200 Subject: [PATCH 10/12] Update documentation with overwrite_entities --- notebooks/spacy_integration.ipynb | 41 ++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/notebooks/spacy_integration.ipynb b/notebooks/spacy_integration.ipynb index 9af305bb..4023d308 100644 --- a/notebooks/spacy_integration.ipynb +++ b/notebooks/spacy_integration.ipynb @@ -41,7 +41,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -58,7 +58,7 @@ " BCE)" ] }, - "execution_count": 11, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -89,7 +89,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -192,7 +192,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -211,9 +211,9 @@ " PERSON\n", "\n", ", also known as \n", - "\n", + "\n", " Cleopatra the Great\n", - " PERSON\n", + " WORK_OF_ART\n", "\n", ", was the last active ruler of \n", "\n", @@ -266,20 +266,30 @@ "source": [ "Much better!\n", "\n", - "But, what if we don't want to use a model with these labels? Well, this integration works for any [SpanMarker model on the Hugging Face Hub](https://huggingface.co/models?library=span-marker), so we can just pick another one. Let's now also ensure that the model stays on the CPU, just to see how that works." + "But, what if we don't want to use a model with these labels? Well, this integration works for any [SpanMarker model on the Hugging Face Hub](https://huggingface.co/models?library=span-marker), so we can just pick another one. Let's now also ensure that the model stays on the CPU, just to see how that works. Beyond that, we'll overwrite entities from spaCy's own NER model. This is recommended when the SpanMarker model uses a different label scheme than spaCy, which uses the labels from OntoNotes v5." ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 6, "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "SpanMarker model predictions are being computed on the CPU while CUDA is available. Moving the model to CUDA using `model.cuda()` before performing predictions is heavily recommended to significantly boost prediction speeds.\n" - ] + "data": { + "text/html": [ + "
[11:12:24] WARNING  SpanMarker model predictions are being computed on the CPU while CUDA is        modeling.py:382\n",
+       "                    available. Moving the model to CUDA using `model.cuda()` before performing                     \n",
+       "                    predictions is heavily recommended to significantly boost prediction speeds.                   \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[11:12:24]\u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m SpanMarker model predictions are being computed on the CPU while CUDA is \u001b]8;id=964803;file://C:\\code\\span-marker-ner\\span_marker\\modeling.py\u001b\\\u001b[2mmodeling.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=664101;file://C:\\code\\span-marker-ner\\span_marker\\modeling.py#382\u001b\\\u001b[2m382\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m available. Moving the model to CUDA using `\u001b[1;35mmodel.cuda\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m` before performing \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m predictions is heavily recommended to significantly boost prediction speeds. \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { "name": "stdout", @@ -328,6 +338,7 @@ " config={\n", " \"model\": \"tomaarsen/span-marker-xlm-roberta-base-fewnerd-fine-super\",\n", " \"device\": \"cpu\",\n", + " \"overwrite_entities\": True,\n", " },\n", ")\n", "\n", @@ -347,7 +358,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -360,7 +371,7 @@ " (Paris, 'GPE')]" ] }, - "execution_count": 16, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } From cf0ad39b0aaa77553807355da2ea2c307f2745cd Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 24 Aug 2023 12:37:17 +0200 Subject: [PATCH 11/12] Reintroduce accidentally removed "Fixed" header --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8bbb1df4..63262786 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,8 @@ Types of changes ## [1.2.4] +### Fixed + - Resolved broken license information. ## [1.2.3] From 115c3eab445aed8db7c45375da85fb8ee345960b Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 24 Aug 2023 12:42:26 +0200 Subject: [PATCH 12/12] Prefer SpanMarker outputs over spaCy outputs --- notebooks/spacy_integration.ipynb | 24 +++++++----------------- span_marker/spacy_integration.py | 2 +- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/notebooks/spacy_integration.ipynb b/notebooks/spacy_integration.ipynb index 4023d308..3e245460 100644 --- a/notebooks/spacy_integration.ipynb +++ b/notebooks/spacy_integration.ipynb @@ -211,9 +211,9 @@ " PERSON\n", "
\n", ", also known as \n", - "\n", + "\n", " Cleopatra the Great\n", - " WORK_OF_ART\n", + " PERSON\n", "\n", ", was the last active ruler of \n", "\n", @@ -275,21 +275,11 @@ "metadata": {}, "outputs": [ { - "data": { - "text/html": [ - "
[11:12:24] WARNING  SpanMarker model predictions are being computed on the CPU while CUDA is        modeling.py:382\n",
-       "                    available. Moving the model to CUDA using `model.cuda()` before performing                     \n",
-       "                    predictions is heavily recommended to significantly boost prediction speeds.                   \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[2;36m[11:12:24]\u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m SpanMarker model predictions are being computed on the CPU while CUDA is \u001b]8;id=964803;file://C:\\code\\span-marker-ner\\span_marker\\modeling.py\u001b\\\u001b[2mmodeling.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=664101;file://C:\\code\\span-marker-ner\\span_marker\\modeling.py#382\u001b\\\u001b[2m382\u001b[0m\u001b]8;;\u001b\\\n", - "\u001b[2;36m \u001b[0m available. Moving the model to CUDA using `\u001b[1;35mmodel.cuda\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m` before performing \u001b[2m \u001b[0m\n", - "\u001b[2;36m \u001b[0m predictions is heavily recommended to significantly boost prediction speeds. \u001b[2m \u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "SpanMarker model predictions are being computed on the CPU while CUDA is available. Moving the model to CUDA using `model.cuda()` before performing predictions is heavily recommended to significantly boost prediction speeds.\n" + ] }, { "name": "stdout", diff --git a/span_marker/spacy_integration.py b/span_marker/spacy_integration.py index 9778e3aa..8abb7c18 100644 --- a/span_marker/spacy_integration.py +++ b/span_marker/spacy_integration.py @@ -92,7 +92,7 @@ def set_ents(self, doc: Doc, ents: List[Span]): if self.overwrite_entities: doc.set_ents(ents) else: - doc.set_ents(filter_spans(list(doc.ents) + ents)) + doc.set_ents(filter_spans(ents + list(doc.ents))) def __call__(self, doc: Doc) -> Doc: """Fill `doc.ents` and `span.label_` using the chosen SpanMarker model."""