diff --git a/libs/labelbox/src/labelbox/__init__.py b/libs/labelbox/src/labelbox/__init__.py index 9759cb2e3..a81282f60 100644 --- a/libs/labelbox/src/labelbox/__init__.py +++ b/libs/labelbox/src/labelbox/__init__.py @@ -98,3 +98,4 @@ PromptResponseClassification, ) from lbox.exceptions import * +from labelbox.schema.taskstatus import TaskStatus diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index 956f12487..1926957a5 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -79,6 +79,7 @@ from labelbox.schema.slice import CatalogSlice, ModelSlice from labelbox.schema.task import DataUpsertTask, Task from labelbox.schema.user import User +from labelbox.schema.taskstatus import TaskStatus logger = logging.getLogger(__name__) @@ -90,6 +91,9 @@ class Client: top-level data objects (Projects, Datasets). """ + # Class variable to cache task types + _cancelable_task_types = None + def __init__( self, api_key=None, @@ -2390,9 +2394,31 @@ def get_task_by_id(self, task_id: str) -> Union[Task, DataUpsertTask]: task._user = user return task + def _get_cancelable_task_types(self): + """Internal method that returns a list of task types that can be canceled. + + The result is cached after the first call to avoid unnecessary API requests. + + Returns: + List[str]: List of cancelable task types in snake_case format + """ + if self._cancelable_task_types is None: + query = """query GetCancelableTaskTypesPyApi { + cancelableTaskTypes + }""" + + result = self.execute(query).get("cancelableTaskTypes", []) + # Reformat to kebab case + self._cancelable_task_types = [ + utils.snake_case(task_type).replace("_", "-") + for task_type in result + ] + + return self._cancelable_task_types + def cancel_task(self, task_id: str) -> bool: """ - Cancels a task with the given ID. + Cancels a task with the given ID if the task type is cancelable and the task is in progress. Args: task_id (str): The ID of the task to cancel. @@ -2401,8 +2427,26 @@ def cancel_task(self, task_id: str) -> bool: bool: True if the task was successfully cancelled. Raises: - LabelboxError: If the task could not be cancelled. + LabelboxError: If the task could not be cancelled, if the task type is not cancelable, + or if the task is not in progress. + ResourceNotFoundError: If the task does not exist (raised by get_task_by_id). """ + # Get the task object to check its type and status + task = self.get_task_by_id(task_id) + + # Check if task type is cancelable + cancelable_types = self._get_cancelable_task_types() + if task.type not in cancelable_types: + raise LabelboxError( + f"Task type '{task.type}' cannot be cancelled. Cancelable types are: {cancelable_types}" + ) + + # Check if task is in progress + if task.status_as_enum != TaskStatus.In_Progress: + raise LabelboxError( + f"Task cannot be cancelled because it is not in progress. Current status: {task.status}" + ) + mutation_str = """ mutation CancelTaskPyApi($id: ID!) { cancelBulkOperationJob(id: $id) { diff --git a/libs/labelbox/src/labelbox/schema/__init__.py b/libs/labelbox/src/labelbox/schema/__init__.py index d6b74de68..b22a54854 100644 --- a/libs/labelbox/src/labelbox/schema/__init__.py +++ b/libs/labelbox/src/labelbox/schema/__init__.py @@ -26,3 +26,4 @@ import labelbox.schema.catalog import labelbox.schema.ontology_kind import labelbox.schema.project_overview +import labelbox.schema.taskstatus diff --git a/libs/labelbox/src/labelbox/schema/organization.py b/libs/labelbox/src/labelbox/schema/organization.py index bd416e997..1eea3aebf 100644 --- a/libs/labelbox/src/labelbox/schema/organization.py +++ b/libs/labelbox/src/labelbox/schema/organization.py @@ -52,6 +52,7 @@ def __init__(self, *args, **kwargs): projects = Relationship.ToMany("Project", True) webhooks = Relationship.ToMany("Webhook", False) resource_tags = Relationship.ToMany("ResourceTags", False) + tasks = Relationship.ToMany("Task", False, "tasks") def invite_user( self, diff --git a/libs/labelbox/src/labelbox/schema/task.py b/libs/labelbox/src/labelbox/schema/task.py index d536b2560..04fd7b12f 100644 --- a/libs/labelbox/src/labelbox/schema/task.py +++ b/libs/labelbox/src/labelbox/schema/task.py @@ -14,6 +14,7 @@ from labelbox.schema.internal.datarow_upload_constants import ( DOWNLOAD_RESULT_PAGE_SIZE, ) +from labelbox.schema.taskstatus import TaskStatus if TYPE_CHECKING: from labelbox import User @@ -45,6 +46,9 @@ class Task(DbObject): created_at = Field.DateTime("created_at") name = Field.String("name") status = Field.String("status") + status_as_enum = Field.Enum( + TaskStatus, "status_as_enum", "status" + ) # additional status for filtering completion_percentage = Field.Float("completion_percentage") result_url = Field.String("result_url", "result") errors_url = Field.String("errors_url", "errors") diff --git a/libs/labelbox/src/labelbox/schema/taskstatus.py b/libs/labelbox/src/labelbox/schema/taskstatus.py new file mode 100644 index 000000000..0abbce1ca --- /dev/null +++ b/libs/labelbox/src/labelbox/schema/taskstatus.py @@ -0,0 +1,25 @@ +from enum import Enum + + +class TaskStatus(str, Enum): + In_Progress = "IN_PROGRESS" + Complete = "COMPLETE" + Canceling = "CANCELLING" + Canceled = "CANCELED" + Failed = "FAILED" + Unknown = "UNKNOWN" + + @classmethod + def _missing_(cls, value): + """Handle missing or unknown task status values. + + If a task status value is not found in the enum, this method returns + the Unknown status instead of raising an error. + + Args: + value: The status value that doesn't match any enum member + + Returns: + TaskStatus.Unknown: The default status for unrecognized values + """ + return cls.Unknown diff --git a/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py index 5ab4d30ec..233fc2144 100644 --- a/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py @@ -1,6 +1,6 @@ import time -from labelbox import DataRow, ExportTask, StreamType +from labelbox import DataRow, ExportTask, StreamType, Task, TaskStatus class TestExportDataRow: @@ -135,3 +135,33 @@ def test_cancel_export_task( # Verify the task was cancelled cancelled_task = client.get_task_by_id(export_task.uid) assert cancelled_task.status in ["CANCELING", "CANCELED"] + + def test_task_filter(self, client, data_row, wait_for_data_row_processing): + organization = client.get_organization() + user = client.get_user() + + export_task = DataRow.export( + client=client, + data_rows=[data_row], + task_name="TestExportDataRow:test_task_filter", + ) + + # Check if task is listed "in progress" in organization's tasks + org_tasks_in_progress = organization.tasks( + where=Task.status_as_enum == TaskStatus.In_Progress + ) + retrieved_task_in_progress = next( + (t for t in org_tasks_in_progress if t.uid == export_task.uid), "" + ) + assert getattr(retrieved_task_in_progress, "uid", "") == export_task.uid + + export_task.wait_till_done() + + # Check if task is listed "complete" in user's created tasks + user_tasks_complete = user.created_tasks( + where=Task.status_as_enum == TaskStatus.Complete + ) + retrieved_task_complete = next( + (t for t in user_tasks_complete if t.uid == export_task.uid), "" + ) + assert getattr(retrieved_task_complete, "uid", "") == export_task.uid