diff --git a/django_tasks/backends/database/backend.py b/django_tasks/backends/database/backend.py index d4de550..5a51ed9 100644 --- a/django_tasks/backends/database/backend.py +++ b/django_tasks/backends/database/backend.py @@ -23,12 +23,12 @@ class TaskResult(BaseTaskResult[T]): def refresh(self) -> None: self.db_result.refresh_from_db() - for attr, value in asdict(self.db_result.get_task_result()).items(): + for attr, value in asdict(self.db_result.task_result).items(): setattr(self, attr, value) async def arefresh(self) -> None: await self.db_result.arefresh_from_db() - for attr, value in asdict(self.db_result.get_task_result()).items(): + for attr, value in asdict(self.db_result.task_result).items(): setattr(self, attr, value) @@ -59,7 +59,7 @@ def enqueue( db_result.save() - return db_result.get_task_result() + return db_result.task_result async def aenqueue( self, task: Task[P, T], args: P.args, kwargs: P.kwargs @@ -70,13 +70,13 @@ async def aenqueue( await db_result.asave() - return db_result.get_task_result() + return db_result.task_result def get_result(self, result_id: str) -> TaskResult: from .models import DBTaskResult try: - return DBTaskResult.objects.get(id=result_id).get_task_result() + return DBTaskResult.objects.get(id=result_id).task_result except (DBTaskResult.DoesNotExist, ValidationError) as e: raise ResultDoesNotExist(result_id) from e @@ -84,6 +84,6 @@ async def aget_result(self, result_id: str) -> TaskResult: from .models import DBTaskResult try: - return (await DBTaskResult.objects.aget(id=result_id)).get_task_result() + return (await DBTaskResult.objects.aget(id=result_id)).task_result except (DBTaskResult.DoesNotExist, ValidationError) as e: raise ResultDoesNotExist(result_id) from e diff --git a/django_tasks/backends/database/migrations/0001_initial.py b/django_tasks/backends/database/migrations/0001_initial.py index 8cfa1ee..1593c52 100644 --- a/django_tasks/backends/database/migrations/0001_initial.py +++ b/django_tasks/backends/database/migrations/0001_initial.py @@ -1,5 +1,3 @@ -# Generated by Django 4.2.13 on 2024-05-24 10:46 - import uuid from django.db import migrations, models diff --git a/django_tasks/backends/database/models.py b/django_tasks/backends/database/models.py index 58c651e..15c6356 100644 --- a/django_tasks/backends/database/models.py +++ b/django_tasks/backends/database/models.py @@ -1,17 +1,20 @@ import uuid -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Generic, TypeVar from django.db import models -from django.utils.functional import cached_property from django.utils.module_loading import import_string +from typing_extensions import ParamSpec from django_tasks.task import ResultStatus, Task +T = TypeVar("T") +P = ParamSpec("P") + if TYPE_CHECKING: from .backend import TaskResult -class DBTaskResult(models.Model): +class DBTaskResult(Generic[P, T], models.Model): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) status = models.CharField( @@ -33,8 +36,8 @@ class DBTaskResult(models.Model): result = models.JSONField(default=None, null=True) - @cached_property - def task(self) -> Task: + @property + def task(self) -> Task[P, T]: task = import_string(self.task_path) assert isinstance(task, Task) @@ -46,10 +49,11 @@ def task(self) -> Task: backend=self.backend_name, ) - def get_task_result(self) -> "TaskResult": + @property + def task_result(self) -> "TaskResult[T]": from .backend import TaskResult - result = TaskResult[Any]( + result = TaskResult[T]( db_result=self, task=self.task, id=str(self.id),