Skip to content
This repository has been archived by the owner on Jun 14, 2024. It is now read-only.

Commit

Permalink
feat: overlapped span questions (#96)
Browse files Browse the repository at this point in the history
# Description

Feature branch to support span questions overlapping.

Refs argilla-io/argilla#1750

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] Refactor (change restructuring the codebase without changing
functionality)
- [ ] Improvement (change adding some improvement to an existing
functionality)
- [ ] Documentation update

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- [ ] Test A
- [ ] Test B

**Checklist**

- [ ] I added relevant documentation
- [ ] follows the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: Francisco Aranda <francis@argilla.io>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 16, 2024
1 parent 771e852 commit 876307a
Show file tree
Hide file tree
Showing 10 changed files with 208 additions and 33 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ These are the section headers that we use:

### Added

- Added `allow_overlapping` field for creation and update of span question settings. ([#89](https://github.com/argilla-io/argilla-server/pull/89))
- Added `ARGILLA_LABEL_SELECTION_OPTIONS_MAX_ITEMS` environment variable to set the number of maximum items to be used by label and multi label questions. By default this value is set to `500`. ([#85](https://github.com/argilla-io/argilla-server/pull/85))
- Added `ARGILLA_SPAN_OPTIONS_MAX_ITEMS` environment variable to set the number of maximum items to be used by span questions. By default this value is set to `500`. ([#85](https://github.com/argilla-io/argilla-server/pull/85))
- Added `GET /api/v1/datasets/:dataset_id/progress` endpoint to return progress metrics related with one specific dataset. ([#90](https://github.com/argilla-io/argilla-server/pull/90))
Expand Down
2 changes: 1 addition & 1 deletion argilla
Submodule argilla updated 60 files
+12 −1 CHANGELOG.md
+ docs/_source/_static/images/llms/training-setfit-absa/setfit-absa.png
+ docs/_source/_static/images/llms/training-setfit-absa/setfit-basa-argilla.png
+7 −2 docs/_source/getting_started/installation/configurations/server_configuration.md
+4 −4 docs/_source/practical_guides/create_update_dataset/vectors.md
+1,079 −0 docs/_source/tutorials_and_integrations/tutorials/feedback/training-setfit-absa.ipynb
+6 −0 docs/_source/tutorials_and_integrations/tutorials/tutorials.md
+1 −0 environment_dev.yml
+4 −3 frontend/components/base/base-card/BaseCardWithTabs.vue
+0 −157 frontend/components/base/base-progress/BaseLinearProgress.vue
+0 −50 frontend/components/base/base-progress/BaseLinearProgressSkeleton.vue
+5 −11 frontend/components/base/base-table/BaseTableInfo.vue
+5 −4 frontend/components/features/datasets/datasets-table/DatasetsTable.vue
+0 −52 frontend/components/features/datasets/datasets-table/dataset-progress/DatasetProgress.vue
+0 −62 frontend/components/features/datasets/datasets-table/dataset-progress/useDatasetProgressViewModel.ts
+1 −0 frontend/components/feedback-task/container/questions/form/ranking/Ranking.component.vue
+6 −12 ...omponents/feedback-task/container/questions/form/ranking/drag-and-drop-selection/DndSelection.component.vue
+1 −1 frontend/components/feedback-task/container/questions/form/rating/Rating.component.vue
+30 −20 frontend/components/feedback-task/container/questions/form/rating/RatingMonoSelection.component.vue
+1 −2 frontend/components/feedback-task/container/questions/form/rating/ratingMonoSelection.component.spec.js
+24 −20 ...nents/feedback-task/container/questions/form/shared-components/label-selection/LabelSelection.component.vue
+0 −21 ...nents/feedback-task/container/questions/form/shared-components/question-header/QuestionHeader.component.vue
+1 −0 frontend/components/feedback-task/container/questions/form/span/EntityLabelSelection.component.vue
+2 −22 frontend/components/feedback-task/container/questions/form/text-area/TextArea.component.vue
+1 −1 frontend/components/feedback-task/settings/useDeleteDatasetViewModel.ts
+1 −1 frontend/components/feedback-task/settings/useSettingInfoViewModel.ts
+1 −1 frontend/components/feedback-task/settings/useSettingsMetadataViewModel.ts
+20 −11 frontend/components/feedback-task/sidebar/sidebar-feedback-task/SidebarFeedbackTaskProgress.vue
+1 −1 frontend/pages/dataset/_id/useDatasetSettingViewModel.ts
+1 −1 frontend/pages/dataset/_id/useDatasetViewModel.ts
+3 −3 frontend/specs/components/core/table/__snapshots__/BaseTableInfo.spec.js.snap
+0 −7 frontend/translation/en.js
+0 −5 frontend/v1/di/di.ts
+2 −10 frontend/v1/domain/entities/Dataset.ts
+4 −4 frontend/v1/domain/entities/DatasetSetting.ts
+0 −0 frontend/v1/domain/entities/Metrics.test.ts
+0 −0 frontend/v1/domain/entities/Metrics.ts
+0 −9 frontend/v1/domain/entities/dataset/Progress.ts
+1 −5 frontend/v1/domain/services/IDatasetRepository.ts
+1 −1 frontend/v1/domain/services/IDatasetSettingStorage.ts
+1 −1 frontend/v1/domain/services/IDatasetStorage.ts
+1 −1 frontend/v1/domain/services/IDatasetsStorage.ts
+1 −1 frontend/v1/domain/services/IMetricsStorage.ts
+1 −1 frontend/v1/domain/usecases/dataset-setting/get-dataset-settings-use-case.ts
+3 −3 frontend/v1/domain/usecases/dataset-setting/update-dataset-setting-use-case.ts
+2 −2 frontend/v1/domain/usecases/delete-dataset-use-case.ts
+0 −9 frontend/v1/domain/usecases/get-dataset-progress-use-case.ts
+1 −27 frontend/v1/infrastructure/repositories/DatasetRepository.ts
+1 −1 frontend/v1/infrastructure/repositories/MetricsRepository.ts
+1 −1 frontend/v1/infrastructure/services/useRoutes.ts
+1 −1 frontend/v1/infrastructure/storage/DatasetSettingStorage.ts
+1 −1 frontend/v1/infrastructure/storage/DatasetStorage.ts
+1 −1 frontend/v1/infrastructure/storage/DatasetsStorage.ts
+1 −1 frontend/v1/infrastructure/storage/MetricsStorage.ts
+0 −8 frontend/v1/infrastructure/types/dataset.ts
+1 −4 pyproject.toml
+16 −4 src/argilla/client/feedback/metrics/agreement_metrics.py
+1 −1 src/argilla/listeners/listener.py
+9 −4 src/argilla/listeners/models.py
+1 −1 tests/integration/client/test_workspaces.py
8 changes: 1 addition & 7 deletions src/argilla_server/contexts/questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,16 @@
from typing import Union
from uuid import UUID

from sqlalchemy import func, select
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

import argilla_server.errors.future as errors
from argilla_server.enums import QuestionType
from argilla_server.models import Dataset, Question, User
from argilla_server.policies import QuestionPolicyV1, authorize
from argilla_server.schemas.v1.questions import (
LabelSelectionQuestionSettings,
LabelSelectionSettingsUpdate,
QuestionCreate,
QuestionSettings,
QuestionSettingsUpdate,
QuestionUpdate,
SpanQuestionSettingsCreate,
)
from argilla_server.validators.questions import (
QuestionCreateValidator,
Expand Down
2 changes: 2 additions & 0 deletions src/argilla_server/schemas/v1/questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ class SpanQuestionSettingsCreate(UniqueValuesCheckerMixin):
max_items=settings.span_options_max_items,
)
visible_options: Optional[int] = Field(None, ge=SPAN_MIN_VISIBLE_OPTIONS)
allow_overlapping: bool = False

@root_validator(skip_on_failure=True)
def check_visible_options_value(cls, values: Dict[str, Any]) -> Dict[str, Any]:
Expand All @@ -271,6 +272,7 @@ class SpanQuestionSettingsUpdate(UpdateSchema):
)
]
visible_options: Optional[int] = Field(None, ge=SPAN_MIN_VISIBLE_OPTIONS)
allow_overlapping: Optional[bool]


QuestionSettings = Annotated[
Expand Down
20 changes: 19 additions & 1 deletion src/argilla_server/validators/questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@

from argilla_server.enums import QuestionType
from argilla_server.models.database import Dataset, Question
from argilla_server.schemas.v1.questions import QuestionCreate, QuestionSettings, QuestionSettingsUpdate, QuestionUpdate
from argilla_server.schemas.v1.questions import (
QuestionCreate,
QuestionSettings,
QuestionSettingsUpdate,
QuestionUpdate,
SpanQuestionSettings,
)


class InvalidQuestionSettings(Exception):
Expand Down Expand Up @@ -76,6 +82,7 @@ def _validate_question_settings(self, question_settings: QuestionSettings):
self._validate_question_settings_type_is_the_same(question_settings, self._question_update.settings)
self._validate_question_settings_label_options(question_settings, self._question_update.settings)
self._validate_question_settings_visible_options(question_settings, self._question_update.settings)
self._validate_span_question_settings(question_settings, self._question_update.settings)

def _validate_question_settings_type_is_the_same(
self, question_settings: QuestionSettings, question_settings_update: QuestionSettingsUpdate
Expand Down Expand Up @@ -127,6 +134,17 @@ def _validate_question_settings_visible_options(
f"the value for 'visible_options' must be less or equal to the number of items in 'options' ({number_of_options})"
)

def _validate_span_question_settings(
self, question_settings: SpanQuestionSettings, question_settings_update: QuestionSettingsUpdate
) -> None:
if question_settings_update.type != QuestionType.span:
return

if question_settings.allow_overlapping and not question_settings_update.allow_overlapping:
raise InvalidQuestionSettings(
"'allow_overlapping' can't be disabled because responses may become inconsistent"
)


class QuestionDeleteValidator:
def validate_for(self, dataset: Dataset):
Expand Down
17 changes: 14 additions & 3 deletions src/argilla_server/validators/response_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@
MultiLabelSelectionQuestionResponseValue,
RankingQuestionResponseValue,
RatingQuestionResponseValue,
ResponseCreate,
ResponseValueCreate,
ResponseValuesCreate,
ResponseValueTypes,
SpanQuestionResponseValue,
TextAndLabelSelectionQuestionResponseValue,
Expand Down Expand Up @@ -222,6 +219,7 @@ def validate_for(self, span_question_settings: SpanQuestionSettings, record: Rec
self._validate_question_settings_field_is_present_at_record(span_question_settings, record)
self._validate_start_end_are_within_record_field_limits(span_question_settings, record)
self._validate_labels_are_available_at_question_settings(span_question_settings)
self._validate_values_are_not_overlapped(span_question_settings)

def _validate_value_type(self) -> None:
if not isinstance(self._response_value, list):
Expand Down Expand Up @@ -257,3 +255,16 @@ def _validate_labels_are_available_at_question_settings(self, span_question_sett
raise ValueError(
f"undefined label '{value_item.label}' for span question.\nValid labels are: {available_labels!r}"
)

def _validate_values_are_not_overlapped(self, span_question_settings: SpanQuestionSettings) -> None:
if span_question_settings.allow_overlapping:
return

for span_i, value_item in enumerate(self._response_value):
for span_j, other_value_item in enumerate(self._response_value):
if (
span_i != span_j
and value_item.start < other_value_item.end
and value_item.end > other_value_item.start
):
raise ValueError(f"overlapping values found between spans at index idx={span_i} and idx={span_j}")
43 changes: 26 additions & 17 deletions tests/unit/api/v1/datasets/test_create_dataset_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from datetime import datetime
from typing import Union
from uuid import UUID

import pytest
Expand All @@ -30,30 +31,38 @@ class TestCreateDatasetQuestion:
def url(self, dataset_id: UUID) -> str:
return f"/api/v1/datasets/{dataset_id}/questions"

@pytest.mark.parametrize(
"allow_overlapping,expected_allow_overlapping", [(None, False), (False, False), (True, True)]
)
async def test_create_dataset_span_question(
self, async_client: AsyncClient, db: AsyncSession, owner_auth_header: dict
self,
async_client: AsyncClient,
db: AsyncSession,
owner_auth_header: dict,
allow_overlapping: Union[bool, None],
expected_allow_overlapping: bool,
):
dataset = await DatasetFactory.create()
await TextFieldFactory.create(name="field-a", dataset=dataset)

settings = {
"type": QuestionType.span,
"field": "field-a",
"visible_options": 3,
"options": [
{"value": "label-a", "text": "Label A", "description": "Label A description"},
{"value": "label-b", "text": "Label B", "description": "Label B description"},
{"value": "label-c", "text": "Label C", "description": "Label C description"},
],
}

if allow_overlapping is not None:
settings["allow_overlapping"] = allow_overlapping

response = await async_client.post(
self.url(dataset.id),
headers=owner_auth_header,
json={
"name": "name",
"title": "Title",
"description": "Description",
"settings": {
"type": QuestionType.span,
"field": "field-a",
"visible_options": 3,
"options": [
{"value": "label-a", "text": "Label A", "description": "Label A description"},
{"value": "label-b", "text": "Label B", "description": "Label B description"},
{"value": "label-c", "text": "Label C", "description": "Label C description"},
],
},
},
json={"name": "name", "title": "Title", "description": "Description", "settings": settings},
)

assert response.status_code == 201
Expand All @@ -75,7 +84,7 @@ async def test_create_dataset_span_question(
{"value": "label-b", "text": "Label B", "description": "Label B description"},
{"value": "label-c", "text": "Label C", "description": "Label C description"},
],
"allow_overlapping": False,
"allow_overlapping": expected_allow_overlapping,
"allow_character_annotation": True,
},
"dataset_id": str(dataset.id),
Expand Down
83 changes: 81 additions & 2 deletions tests/unit/api/v1/questions/test_update_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime
from uuid import UUID

import pytest
from argilla_server.enums import QuestionType
from httpx import AsyncClient

from tests.factories import LabelSelectionQuestionFactory, TextQuestionFactory
from tests.factories import LabelSelectionQuestionFactory, SpanQuestionFactory, TextQuestionFactory


@pytest.mark.asyncio
Expand Down Expand Up @@ -111,3 +111,82 @@ async def test_update_question_with_more_visible_options_than_allowed(
assert response.json() == {
"detail": "the value for 'visible_options' must be less or equal to the number of items in 'options' (3)"
}

async def test_update_span_question_enabling_allow_overlapping(
self, async_client: AsyncClient, owner_auth_header: dict
):
question = await SpanQuestionFactory.create(
settings={
"type": QuestionType.span.value,
"field": "field-a",
"options": [
{"value": "label-a", "text": "Label A", "description": "Label A description"},
{"value": "label-b", "text": "Label B", "description": "Label B description"},
{"value": "label-c", "text": "Label C", "description": "Label C description"},
],
"allow_overlapping": False,
}
)

response = await async_client.patch(
self.url(question.id),
headers=owner_auth_header,
json={
"settings": {"type": QuestionType.span, "allow_overlapping": True},
},
)

assert response.status_code == 200

response_json = response.json()
assert response_json == {
"id": str(question.id),
"name": question.name,
"description": question.description,
"title": question.title,
"dataset_id": str(question.dataset_id),
"required": False,
"settings": {
"type": QuestionType.span.value,
"field": "field-a",
"options": [
{"value": "label-a", "text": "Label A", "description": "Label A description"},
{"value": "label-b", "text": "Label B", "description": "Label B description"},
{"value": "label-c", "text": "Label C", "description": "Label C description"},
],
"allow_overlapping": True,
"allow_character_annotation": True,
"visible_options": None,
},
"inserted_at": datetime.fromisoformat(response_json["inserted_at"]).isoformat(),
"updated_at": datetime.fromisoformat(response_json["updated_at"]).isoformat(),
}

async def test_update_span_question_disabling_allow_overlapping(
self, async_client: AsyncClient, owner_auth_header: dict
):
question = await SpanQuestionFactory.create(
settings={
"type": QuestionType.span.value,
"field": "field-a",
"options": [
{"value": "label-a", "text": "Label A", "description": "Label A description"},
{"value": "label-b", "text": "Label B", "description": "Label B description"},
{"value": "label-c", "text": "Label C", "description": "Label C description"},
],
"allow_overlapping": True,
}
)

response = await async_client.patch(
self.url(question.id),
headers=owner_auth_header,
json={
"settings": {"type": QuestionType.span, "allow_overlapping": False},
},
)

assert response.status_code == 422
assert response.json() == {
"detail": "'allow_overlapping' can't be disabled because responses may become inconsistent"
}
34 changes: 33 additions & 1 deletion tests/unit/api/v1/records/test_create_record_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,39 @@ async def test_create_record_response_for_span_question_with_non_existent_label(

assert response.status_code == 422
assert response.json() == {
"detail": "undefined label 'label-non-existent' for span question.\nValid labels are: ['label-a', 'label-b', 'label-c']"
"detail": "undefined label 'label-non-existent' for span question.\n"
"Valid labels are: ['label-a', 'label-b', 'label-c']"
}

assert (await db.execute(select(func.count(Response.id)))).scalar() == 0

async def test_create_record_response_for_span_question_with_overlapped_values(
self, async_client: AsyncClient, db: AsyncSession, owner: User, owner_auth_header: dict
):
dataset = await DatasetFactory.create()

await SpanQuestionFactory.create(name="span-question", dataset=dataset)

record = await RecordFactory.create(fields={"field-a": "Hello, this is a text field"}, dataset=dataset)

response = await async_client.post(
self.url(record.id),
headers=owner_auth_header,
json={
"values": {
"span-question": {
"value": [
{"label": "label-a", "start": 0, "end": 3},
{"label": "label-a", "start": 6, "end": 8},
{"label": "label-b", "start": 2, "end": 5},
],
},
},
"status": ResponseStatusFilter.submitted,
},
)

assert response.status_code == 422
assert response.json() == {"detail": "overlapping values found between spans at index idx=0 and idx=2"}

assert (await db.execute(select(func.count(Response.id)))).scalar() == 0
31 changes: 30 additions & 1 deletion tests/unit/api/v1/records/test_upsert_suggestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,36 @@ async def test_upsert_suggestion_for_span_question_with_non_existent_label(

assert response.status_code == 422
assert response.json() == {
"detail": "undefined label 'label-non-existent' for span question.\nValid labels are: ['label-a', 'label-b', 'label-c']"
"detail": "undefined label 'label-non-existent' for span question.\n"
"Valid labels are: ['label-a', 'label-b', 'label-c']"
}

assert (await db.execute(select(func.count(Suggestion.id)))).scalar() == 0

async def test_upsert_suggestion_for_span_question_with_overlapped_values(
self, async_client: AsyncClient, db: AsyncSession, owner_auth_header: dict
):
dataset = await DatasetFactory.create()

span_question = await SpanQuestionFactory.create(name="span-question", dataset=dataset)

record = await RecordFactory.create(fields={"field-a": "Hello, this is a text field"}, dataset=dataset)

response = await async_client.put(
self.url(record.id),
headers=owner_auth_header,
json={
"question_id": str(span_question.id),
"type": SuggestionType.model,
"value": [
{"label": "label-a", "start": 0, "end": 3},
{"label": "label-a", "start": 6, "end": 8},
{"label": "label-b", "start": 2, "end": 5},
],
},
)

assert response.status_code == 422
assert response.json() == {"detail": "overlapping values found between spans at index idx=0 and idx=2"}

assert (await db.execute(select(func.count(Suggestion.id)))).scalar() == 0

0 comments on commit 876307a

Please sign in to comment.