diff --git a/CHANGELOG.md b/CHANGELOG.md index dddc6f8754..ad2bb9bcbe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,10 @@ These are the section headers that we use: ## [Unreleased]() +### Added + +- Added `allow_overlapping` parameter for span questions. ([#4697](https://github.com/argilla-io/argilla/pull/4697)) + ## [1.26.1](https://github.com/argilla-io/argilla/compare/v1.26.0...v1.26.1) ### Added diff --git a/environment_dev.yml b/environment_dev.yml index a596ddd667..79f00c2cad 100644 --- a/environment_dev.yml +++ b/environment_dev.yml @@ -66,6 +66,6 @@ dependencies: - ipynbname>=2023.2.0.0 - httpx~=0.26.0 # For now we can just install argilla-server from the GitHub repo - - git+https://github.com/argilla-io/argilla-server.git + - git+https://github.com/argilla-io/argilla-server.git@feat/overlapped-span-questions # install Argilla in editable mode - -e .[listeners] diff --git a/pyproject.toml b/pyproject.toml index 70ba05dcb1..c04b5adfe1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,8 +49,8 @@ dependencies = [ dynamic = ["version"] [project.optional-dependencies] -server = ["argilla-server ~= 1.26.1"] -server-postgresql = ["argilla-server[postgresql] ~= 1.26.1"] +server = ["argilla-server ~= 1.27.0.dev0"] +server-postgresql = ["argilla-server[postgresql] ~= 1.27.0.dev0"] listeners = ["schedule ~= 1.1.0", "prodict ~= 0.8.0"] integrations = [ "PyYAML >= 5.4.1,< 6.1.0", # Required by `argilla.client.feedback.config` just used in `HuggingFaceDatasetMixin` diff --git a/src/argilla/client/feedback/schemas/questions.py b/src/argilla/client/feedback/schemas/questions.py index b69f11d82b..c434ef52bc 100644 --- a/src/argilla/client/feedback/schemas/questions.py +++ b/src/argilla/client/feedback/schemas/questions.py @@ -358,6 +358,7 @@ class SpanQuestion(QuestionSchema): field: str = Field(..., description="The field in the input that the user will be asked to annotate.") labels: Union[Dict[str, str], conlist(Union[str, SpanLabelOption], min_items=1, unique_items=True)] visible_labels: Union[conint(ge=3), None] = _DEFAULT_MAX_VISIBLE_LABELS + allow_overlapping: bool = Field(False, description="Configure span to support overlap") @validator("labels", pre=True) def parse_labels_dict(cls, labels) -> List[SpanLabelOption]: @@ -408,6 +409,7 @@ def server_settings(self) -> Dict[str, Any]: "type": self.type, "field": self.field, "visible_options": self.visible_labels, + "allow_overlapping": self.allow_overlapping, "options": [label.dict() for label in self.labels], } diff --git a/src/argilla/client/feedback/schemas/remote/questions.py b/src/argilla/client/feedback/schemas/remote/questions.py index 92cf13ea2b..41fd15a618 100644 --- a/src/argilla/client/feedback/schemas/remote/questions.py +++ b/src/argilla/client/feedback/schemas/remote/questions.py @@ -160,6 +160,7 @@ def to_local(self) -> SpanQuestion: required=self.required, labels=self.labels, visible_labels=self.visible_labels, + allow_overlapping=self.allow_overlapping, ) @classmethod @@ -168,14 +169,16 @@ def _parse_options_from_api(cls, options: List[Dict[str, str]]) -> List[SpanLabe @classmethod def from_api(cls, payload: "FeedbackQuestionModel") -> "RemoteSpanQuestion": + question_settings = payload.settings return RemoteSpanQuestion( id=payload.id, name=payload.name, title=payload.title, - field=payload.settings["field"], + field=question_settings["field"], required=payload.required, - visible_labels=payload.settings["visible_options"], - labels=cls._parse_options_from_api(payload.settings["options"]), + visible_labels=question_settings["visible_options"], + labels=cls._parse_options_from_api(question_settings["options"]), + allow_overlapping=question_settings["allow_overlapping"], ) diff --git a/tests/integration/client/feedback/dataset/local/test_dataset.py b/tests/integration/client/feedback/dataset/local/test_dataset.py index 431190426a..89820b4299 100644 --- a/tests/integration/client/feedback/dataset/local/test_dataset.py +++ b/tests/integration/client/feedback/dataset/local/test_dataset.py @@ -78,9 +78,34 @@ def test_create_dataset_with_span_questions(argilla_user: "ServerUser") -> None: rg_dataset = ds.push_to_argilla(name="new_dataset") assert rg_dataset.id - assert rg_dataset.questions[0].name == "spans" - assert rg_dataset.questions[0].field == "text" - assert rg_dataset.questions[0].labels == [SpanLabelOption(value="label1"), SpanLabelOption(value="label2")] + question = rg_dataset.questions[0] + assert question.name == "spans" + assert question.field == "text" + assert question.labels == [SpanLabelOption(value="label1"), SpanLabelOption(value="label2")] + assert question.allow_overlapping is False + + +@pytest.mark.parametrize("allow_overlapping", [True, False]) +def test_create_dataset_with_span_questions_allow_overlapping( + argilla_user: "ServerUser", allow_overlapping: bool +) -> None: + argilla.client.singleton.init(api_key=argilla_user.api_key) + + ds = FeedbackDataset( + fields=[TextField(name="text")], + questions=[ + SpanQuestion(name="spans", field="text", labels=["label1", "label2"], allow_overlapping=allow_overlapping) + ], + ) + + rg_dataset = ds.push_to_argilla(name="new_dataset") + + assert rg_dataset.id + question = rg_dataset.questions[0] + assert question.name == "spans" + assert question.field == "text" + assert question.labels == [SpanLabelOption(value="label1"), SpanLabelOption(value="label2")] + assert question.allow_overlapping is allow_overlapping @pytest.mark.asyncio @@ -277,6 +302,71 @@ def test_add_records_with_wrong_spans_suggestions( ) +def test_add_records_with_overlapped_spans(argilla_user: "ServerUser") -> None: + argilla.client.singleton.init(api_key=argilla_user.api_key) + + dataset_cfg = FeedbackDataset( + fields=[TextField(name="text")], + questions=[SpanQuestion(name="spans", field="text", labels=["label1", "label2"], allow_overlapping=True)], + ) + + dataset = dataset_cfg.push_to_argilla(name="test-dataset") + question = dataset.question_by_name("spans") + + dataset.add_records( + [ + FeedbackRecord( + fields={"text": "this is a text"}, + suggestions=[ + question.suggestion( + value=[ + SpanValueSchema(start=0, end=4, label="label1"), + SpanValueSchema(start=1, end=2, label="label2"), + ] + ) + ], + ) + ] + ) + + assert len(dataset.records) == 1 + + record = dataset.records[0] + assert record.suggestions[0].value == [ + SpanValueSchema(start=0, end=4, label="label1"), + SpanValueSchema(start=1, end=2, label="label2"), + ] + + +def test_add_records_with_overlapped_spans_and_disabling_overlapping_span(argilla_user: "ServerUser") -> None: + argilla.client.singleton.init(api_key=argilla_user.api_key) + + dataset_cfg = FeedbackDataset( + fields=[TextField(name="text")], + questions=[SpanQuestion(name="spans", field="text", labels=["label1", "label2"], allow_overlapping=False)], + ) + + dataset = dataset_cfg.push_to_argilla(name="test-dataset") + question = dataset.question_by_name("spans") + + with pytest.raises(ValidationApiError, match="overlapping values found between spans at index idx=0 and idx=1"): + dataset.add_records( + [ + FeedbackRecord( + fields={"text": "this is a text"}, + suggestions=[ + question.suggestion( + value=[ + SpanValueSchema(start=0, end=4, label="label1"), + SpanValueSchema(start=1, end=2, label="label2"), + ] + ) + ], + ) + ] + ) + + def test_add_records_with_vectors() -> None: dataset = FeedbackDataset( fields=[TextField(name="text", required=True)], diff --git a/tests/unit/client/feedback/schemas/remote/test_questions.py b/tests/unit/client/feedback/schemas/remote/test_questions.py index 2b2103d713..a67286f2a6 100644 --- a/tests/unit/client/feedback/schemas/remote/test_questions.py +++ b/tests/unit/client/feedback/schemas/remote/test_questions.py @@ -460,6 +460,7 @@ def test_span_questions_from_api(): "type": "span", "field": "field", "visible_options": None, + "allow_overlapping": False, "options": [ {"text": "Span label a", "value": "a", "description": None}, { @@ -490,6 +491,7 @@ def test_span_questions_from_api_with_visible_labels(): "type": "span", "field": "field", "visible_options": 3, + "allow_overlapping": False, "options": [ {"text": "Span label a", "value": "a", "description": None}, {"text": "Span label b", "value": "b", "description": None}, diff --git a/tests/unit/client/feedback/schemas/test_questions.py b/tests/unit/client/feedback/schemas/test_questions.py index 0e45e46cb8..523513b5b5 100644 --- a/tests/unit/client/feedback/schemas/test_questions.py +++ b/tests/unit/client/feedback/schemas/test_questions.py @@ -455,6 +455,7 @@ def test_span_question() -> None: title="Question", description="Description", required=True, + allow_overlapping=True, labels=["a", "b"], ) @@ -463,6 +464,7 @@ def test_span_question() -> None: "type": "span", "field": "field", "visible_options": None, + "allow_overlapping": True, "options": [{"value": "a", "text": "a", "description": None}, {"value": "b", "text": "b", "description": None}], } @@ -481,6 +483,7 @@ def test_span_question_with_labels_dict() -> None: "type": "span", "field": "field", "visible_options": None, + "allow_overlapping": False, "options": [ {"value": "a", "text": "A text", "description": None}, {"value": "b", "text": "B text", "description": None}, @@ -503,6 +506,7 @@ def test_span_question_with_visible_labels() -> None: "type": "span", "field": "field", "visible_options": 3, + "allow_overlapping": False, "options": [ {"value": "a", "text": "a", "description": None}, {"value": "b", "text": "b", "description": None},