diff --git a/src/argilla_server/contexts/questions.py b/src/argilla_server/contexts/questions.py index 0cabfbbe..66250433 100644 --- a/src/argilla_server/contexts/questions.py +++ b/src/argilla_server/contexts/questions.py @@ -50,6 +50,7 @@ def _validate_settings_type(settings: QuestionSettings, settings_update: Questio def _validate_label_options(settings: LabelSelectionQuestionSettings, settings_update: LabelSelectionSettingsUpdate): + # TODO: Validate visible_options on update if settings_update.options is None: return @@ -142,7 +143,9 @@ async def create_question(db: AsyncSession, dataset: Dataset, question_create: Q ) -async def update_question(db: AsyncSession, question_id: UUID, question_update: QuestionUpdate, current_user: User): +async def update_question( + db: AsyncSession, question_id: UUID, question_update: QuestionUpdate, current_user: User +) -> Question: question = await get_question_by_id(db, question_id) if not question: raise errors.NotFoundError() diff --git a/src/argilla_server/schemas/v1/questions.py b/src/argilla_server/schemas/v1/questions.py index a50194d8..aa080dfe 100644 --- a/src/argilla_server/schemas/v1/questions.py +++ b/src/argilla_server/schemas/v1/questions.py @@ -16,11 +16,8 @@ from typing import Any, Dict, List, Literal, Optional, Union from uuid import UUID -from sqlalchemy.ext.asyncio import AsyncSession - from argilla_server.models import QuestionType -from argilla_server.models.database import Dataset -from argilla_server.pydantic_v1 import BaseModel, Field, PositiveInt, conlist, constr, root_validator, validator +from argilla_server.pydantic_v1 import BaseModel, Field, conlist, constr, root_validator, validator from argilla_server.schemas.base import UpdateSchema from argilla_server.schemas.v1.fields import FieldName @@ -61,63 +58,58 @@ SPAN_OPTIONS_MIN_ITEMS = 1 SPAN_OPTIONS_MAX_ITEMS = 500 +SPAN_MIN_VISIBLE_OPTIONS = 3 -class TextQuestionSettings(BaseModel): - type: Literal[QuestionType.text] - use_markdown: bool = False - - -class RatingQuestionSettingsOption(BaseModel): - value: int - - -class RatingQuestionSettings(BaseModel): - type: Literal[QuestionType.rating] - options: conlist(item_type=RatingQuestionSettingsOption) +class UniqueValuesCheckerMixin(BaseModel): + @root_validator(skip_on_failure=True) + def check_unique_values(cls, values: Dict[str, Any]) -> Dict[str, Any]: + options = values.get("options", []) + seen = set() + duplicates = set() + for option in options: + if option.value in seen: + duplicates.add(option.value) + else: + seen.add(option.value) + if duplicates: + raise ValueError(f"Option values must be unique, found duplicates: {duplicates}") + return values +# Option-based settings class OptionSettings(BaseModel): value: str text: str description: Optional[str] = None -class LabelSelectionQuestionSettings(BaseModel): - type: Literal[QuestionType.label_selection] - options: conlist(item_type=OptionSettings) - visible_options: Optional[int] = None - - -class MultiLabelSelectionQuestionSettings(LabelSelectionQuestionSettings): - type: Literal[QuestionType.multi_label_selection] - - -class RankingQuestionSettings(BaseModel): - type: Literal[QuestionType.ranking] - options: conlist(item_type=OptionSettings) +class OptionSettingsCreate(BaseModel): + value: constr( + min_length=VALUE_TEXT_OPTION_VALUE_MIN_LENGTH, + max_length=VALUE_TEXT_OPTION_VALUE_MAX_LENGTH, + ) + text: constr( + min_length=VALUE_TEXT_OPTION_TEXT_MIN_LENGTH, + max_length=VALUE_TEXT_OPTION_TEXT_MAX_LENGTH, + ) + description: Optional[ + constr( + min_length=VALUE_TEXT_OPTION_DESCRIPTION_MIN_LENGTH, + max_length=VALUE_TEXT_OPTION_DESCRIPTION_MAX_LENGTH, + ) + ] = None -class SpanQuestionSettings(BaseModel): - type: Literal[QuestionType.span] - field: str - options: conlist(item_type=OptionSettings) - # These attributes are read-only for now - allow_overlapping: bool = Field(default=False, description="Allow spans overlapping") - allow_character_annotation: bool = Field(default=True, description="Allow character-level annotation") +# Text question +class TextQuestionSettings(BaseModel): + type: Literal[QuestionType.text] + use_markdown: bool = False -QuestionSettings = Annotated[ - Union[ - TextQuestionSettings, - RatingQuestionSettings, - LabelSelectionQuestionSettings, - MultiLabelSelectionQuestionSettings, - RankingQuestionSettings, - SpanQuestionSettings, - ], - Field(..., discriminator="type"), -] +class TextQuestionSettingsCreate(BaseModel): + type: Literal[QuestionType.text] + use_markdown: bool = False class TextQuestionSettingsUpdate(UpdateSchema): @@ -127,89 +119,14 @@ class TextQuestionSettingsUpdate(UpdateSchema): __non_nullable_fields__ = {"use_markdown"} -class RatingQuestionSettingsUpdate(UpdateSchema): - type: Literal[QuestionType.rating] - - -class LabelSelectionSettingsUpdate(UpdateSchema): - type: Literal[QuestionType.label_selection] - visible_options: Optional[PositiveInt] - options: Optional[conlist(item_type=OptionSettings)] - - -class MultiLabelSelectionQuestionSettingsUpdate(LabelSelectionSettingsUpdate): - type: Literal[QuestionType.multi_label_selection] - - -class RankingQuestionSettingsUpdate(UpdateSchema): - type: Literal[QuestionType.ranking] - - -class SpanQuestionSettingsUpdate(UpdateSchema): - type: Literal[QuestionType.span] - options: Optional[conlist(item_type=OptionSettings)] - - -QuestionSettingsUpdate = Annotated[ - Union[ - TextQuestionSettingsUpdate, - RatingQuestionSettingsUpdate, - LabelSelectionSettingsUpdate, - MultiLabelSelectionQuestionSettingsUpdate, - RankingQuestionSettingsUpdate, - SpanQuestionSettingsUpdate, - ], - Field(..., discriminator="type"), -] - - -QuestionName = Annotated[ - constr( - regex=QUESTION_CREATE_NAME_REGEX, - min_length=QUESTION_CREATE_NAME_MIN_LENGTH, - max_length=QUESTION_CREATE_NAME_MAX_LENGTH, - ), - Field(..., description="The name of the question"), -] - - -QuestionTitle = Annotated[ - constr( - min_length=QUESTION_CREATE_TITLE_MIN_LENGTH, - max_length=QUESTION_CREATE_TITLE_MAX_LENGTH, - ), - Field(..., description="The title of the question"), -] - - -QuestionDescription = Annotated[ - constr( - min_length=QUESTION_CREATE_DESCRIPTION_MIN_LENGTH, - max_length=QUESTION_CREATE_DESCRIPTION_MAX_LENGTH, - ), - Field(..., description="The description of the question"), -] - - -class UniqueValuesCheckerMixin(BaseModel): - @root_validator(skip_on_failure=True) - def check_unique_values(cls, values: Dict[str, Any]) -> Dict[str, Any]: - options = values.get("options", []) - seen = set() - duplicates = set() - for option in options: - if option.value in seen: - duplicates.add(option.value) - else: - seen.add(option.value) - if duplicates: - raise ValueError(f"Option values must be unique, found duplicates: {duplicates}") - return values +# Rating question +class RatingQuestionSettingsOption(BaseModel): + value: int -class TextQuestionSettingsCreate(BaseModel): - type: Literal[QuestionType.text] - use_markdown: bool = False +class RatingQuestionSettings(BaseModel): + type: Literal[QuestionType.rating] + options: conlist(item_type=RatingQuestionSettingsOption) class RatingQuestionSettingsCreate(UniqueValuesCheckerMixin): @@ -232,21 +149,15 @@ def check_option_value_range(cls, options: List[RatingQuestionSettingsOption]): return options -class OptionSettingsCreate(BaseModel): - value: constr( - min_length=VALUE_TEXT_OPTION_VALUE_MIN_LENGTH, - max_length=VALUE_TEXT_OPTION_VALUE_MAX_LENGTH, - ) - text: constr( - min_length=VALUE_TEXT_OPTION_TEXT_MIN_LENGTH, - max_length=VALUE_TEXT_OPTION_TEXT_MAX_LENGTH, - ) - description: Optional[ - constr( - min_length=VALUE_TEXT_OPTION_DESCRIPTION_MIN_LENGTH, - max_length=VALUE_TEXT_OPTION_DESCRIPTION_MAX_LENGTH, - ) - ] = None +class RatingQuestionSettingsUpdate(UpdateSchema): + type: Literal[QuestionType.rating] + + +# Label selection question +class LabelSelectionQuestionSettings(BaseModel): + type: Literal[QuestionType.label_selection] + options: List[OptionSettings] + visible_options: Optional[int] = None class LabelSelectionQuestionSettingsCreate(UniqueValuesCheckerMixin): @@ -272,10 +183,37 @@ def check_visible_options_value(cls, values: Dict[str, Any]) -> Dict[str, Any]: return values +class LabelSelectionSettingsUpdate(UpdateSchema): + type: Literal[QuestionType.label_selection] + visible_options: Optional[int] = Field(None, ge=LABEL_SELECTION_MIN_VISIBLE_OPTIONS) + options: Optional[ + conlist( + item_type=OptionSettings, + min_items=LABEL_SELECTION_OPTIONS_MIN_ITEMS, + max_items=LABEL_SELECTION_OPTIONS_MAX_ITEMS, + ) + ] + + +# Multi-label selection question +class MultiLabelSelectionQuestionSettings(LabelSelectionQuestionSettings): + type: Literal[QuestionType.multi_label_selection] + + class MultiLabelSelectionQuestionSettingsCreate(LabelSelectionQuestionSettingsCreate): type: Literal[QuestionType.multi_label_selection] +class MultiLabelSelectionQuestionSettingsUpdate(LabelSelectionSettingsUpdate): + type: Literal[QuestionType.multi_label_selection] + + +# Ranking question +class RankingQuestionSettings(BaseModel): + type: Literal[QuestionType.ranking] + options: List[OptionSettings] + + class RankingQuestionSettingsCreate(UniqueValuesCheckerMixin): type: Literal[QuestionType.ranking] options: conlist( @@ -285,6 +223,21 @@ class RankingQuestionSettingsCreate(UniqueValuesCheckerMixin): ) +class RankingQuestionSettingsUpdate(UpdateSchema): + type: Literal[QuestionType.ranking] + + +# Span question +class SpanQuestionSettings(BaseModel): + type: Literal[QuestionType.span] + field: str + options: List[OptionSettings] + visible_options: Optional[int] = None + # These attributes are read-only for now + allow_overlapping: bool = Field(default=False, description="Allow spans overlapping") + allow_character_annotation: bool = Field(default=True, description="Allow character-level annotation") + + class SpanQuestionSettingsCreate(UniqueValuesCheckerMixin): type: Literal[QuestionType.span] field: FieldName @@ -293,6 +246,73 @@ class SpanQuestionSettingsCreate(UniqueValuesCheckerMixin): min_items=SPAN_OPTIONS_MIN_ITEMS, max_items=SPAN_OPTIONS_MAX_ITEMS, ) + visible_options: Optional[int] = Field(None, ge=SPAN_MIN_VISIBLE_OPTIONS) + + @root_validator(skip_on_failure=True) + def check_visible_options_value(cls, values: Dict[str, Any]) -> Dict[str, Any]: + visible_options = values.get("visible_options") + if visible_options is not None: + num_options = len(values["options"]) + if visible_options > num_options: + raise ValueError( + "The value for 'visible_options' must be less or equal to the number of items in 'options'" + f" ({num_options})" + ) + + return values + + +class SpanQuestionSettingsUpdate(UpdateSchema): + type: Literal[QuestionType.span] + options: Optional[ + conlist( + item_type=OptionSettings, + min_items=SPAN_OPTIONS_MIN_ITEMS, + max_items=SPAN_OPTIONS_MAX_ITEMS, + ) + ] + visible_options: Optional[int] = Field(None, ge=SPAN_MIN_VISIBLE_OPTIONS) + + +QuestionSettings = Annotated[ + Union[ + TextQuestionSettings, + RatingQuestionSettings, + LabelSelectionQuestionSettings, + MultiLabelSelectionQuestionSettings, + RankingQuestionSettings, + SpanQuestionSettings, + ], + Field(..., discriminator="type"), +] + + +QuestionName = Annotated[ + constr( + regex=QUESTION_CREATE_NAME_REGEX, + min_length=QUESTION_CREATE_NAME_MIN_LENGTH, + max_length=QUESTION_CREATE_NAME_MAX_LENGTH, + ), + Field(..., description="The name of the question"), +] + + +QuestionTitle = Annotated[ + constr( + min_length=QUESTION_CREATE_TITLE_MIN_LENGTH, + max_length=QUESTION_CREATE_TITLE_MAX_LENGTH, + ), + Field(..., description="The title of the question"), +] + + +QuestionDescription = Annotated[ + constr( + min_length=QUESTION_CREATE_DESCRIPTION_MIN_LENGTH, + max_length=QUESTION_CREATE_DESCRIPTION_MAX_LENGTH, + ), + Field(..., description="The description of the question"), +] QuestionSettingsCreate = Annotated[ @@ -308,6 +328,19 @@ class SpanQuestionSettingsCreate(UniqueValuesCheckerMixin): ] +QuestionSettingsUpdate = Annotated[ + Union[ + TextQuestionSettingsUpdate, + RatingQuestionSettingsUpdate, + LabelSelectionSettingsUpdate, + MultiLabelSelectionQuestionSettingsUpdate, + RankingQuestionSettingsUpdate, + SpanQuestionSettingsUpdate, + ], + Field(..., discriminator="type"), +] + + class Question(BaseModel): id: UUID name: str diff --git a/tests/factories.py b/tests/factories.py index 9d3199ff..9a24ca65 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -371,6 +371,7 @@ class SpanQuestionFactory(QuestionFactory): settings = { "type": QuestionType.span.value, "field": "field-a", + "visible_options": None, "options": [ {"value": "label-a", "text": "Label A", "description": "Label A description"}, {"value": "label-b", "text": "Label B", "description": "Label B description"}, diff --git a/tests/unit/api/v1/datasets/test_create_dataset_question.py b/tests/unit/api/v1/datasets/test_create_dataset_question.py index 4fc40b90..178f9f86 100644 --- a/tests/unit/api/v1/datasets/test_create_dataset_question.py +++ b/tests/unit/api/v1/datasets/test_create_dataset_question.py @@ -46,9 +46,11 @@ async def test_create_dataset_span_question( "settings": { "type": QuestionType.span, "field": "field-a", + "visible_options": 3, "options": [ {"value": "label-a", "text": "Label A", "description": "Label A description"}, {"value": "label-b", "text": "Label B", "description": "Label B description"}, + {"value": "label-c", "text": "Label C", "description": "Label C description"}, ], }, }, @@ -67,9 +69,11 @@ async def test_create_dataset_span_question( "settings": { "type": QuestionType.span, "field": "field-a", + "visible_options": 3, "options": [ {"value": "label-a", "text": "Label A", "description": "Label A description"}, {"value": "label-b", "text": "Label B", "description": "Label B description"}, + {"value": "label-c", "text": "Label C", "description": "Label C description"}, ], "allow_overlapping": False, "allow_character_annotation": True, @@ -138,3 +142,43 @@ async def test_create_dataset_question_with_other_span_question_using_the_same_f assert response.status_code == 422 assert response.json() == {"detail": f"'field-a' is already used by span question with id '{question.id}'"} + + @pytest.mark.parametrize( + "visible_options,error_msg", + [ + (1, "ensure this value is greater than or equal to 3"), + (4, "The value for 'visible_options' must be less or equal to the number of items in 'options' (3)"), + ], + ) + async def test_create_question_with_wrong_visible_options( + self, + async_client: AsyncClient, + db: AsyncSession, + owner_auth_header: dict, + visible_options: int, + error_msg: str, + ): + dataset = await DatasetFactory.create() + await TextFieldFactory.create(name="field-a", dataset=dataset) + + response = await async_client.post( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "name": "name", + "title": "Title", + "settings": { + "type": QuestionType.span, + "field": "field-a", + "visible_options": visible_options, + "options": [ + {"value": "label-a", "text": "Label A", "description": "Label A description"}, + {"value": "label-b", "text": "Label B", "description": "Label B description"}, + {"value": "label-c", "text": "Label C", "description": "Label C description"}, + ], + }, + }, + ) + + assert response.status_code == 422 + assert error_msg in response.text diff --git a/tests/unit/api/v1/datasets/test_questions.py b/tests/unit/api/v1/datasets/test_questions.py index f483c423..f76ba72c 100644 --- a/tests/unit/api/v1/datasets/test_questions.py +++ b/tests/unit/api/v1/datasets/test_questions.py @@ -156,6 +156,33 @@ class TestDatasetQuestions: { "type": "span", "field": "field-a", + "visible_options": None, + "options": [ + {"value": "label-a", "text": "Label A", "description": None}, + {"value": "label-b", "text": "Label B", "description": None}, + {"value": "label-c", "text": "Label C", "description": None}, + {"value": "label-d", "text": "Label D", "description": None}, + ], + "allow_character_annotation": True, + "allow_overlapping": False, + }, + ), + ( + { + "type": "span", + "field": "field-a", + "visible_options": 3, + "options": [ + {"value": "label-a", "text": "Label A"}, + {"value": "label-b", "text": "Label B"}, + {"value": "label-c", "text": "Label C"}, + {"value": "label-d", "text": "Label D"}, + ], + }, + { + "type": "span", + "field": "field-a", + "visible_options": 3, "options": [ {"value": "label-a", "text": "Label A", "description": None}, {"value": "label-b", "text": "Label B", "description": None}, diff --git a/tests/unit/api/v1/test_questions.py b/tests/unit/api/v1/test_questions.py index 12e9c061..07e896c8 100644 --- a/tests/unit/api/v1/test_questions.py +++ b/tests/unit/api/v1/test_questions.py @@ -201,6 +201,28 @@ ], "allow_overlapping": False, "allow_character_annotation": True, + "visible_options": None, + }, + ), + ( + SpanQuestionFactory, + { + "settings": { + "type": "span", + "visible_options": 3, + } + }, + { + "type": "span", + "field": "field-a", + "options": [ + {"value": "label-a", "text": "Label A", "description": "Label A description"}, + {"value": "label-b", "text": "Label B", "description": "Label B description"}, + {"value": "label-c", "text": "Label C", "description": "Label C description"}, + ], + "allow_overlapping": False, + "allow_character_annotation": True, + "visible_options": 3, }, ), ],