Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

add regression test for dataclass as input to submitted tasks #130

Merged
merged 7 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 0 additions & 24 deletions .github/workflows/add-to-project.yml

This file was deleted.

3 changes: 2 additions & 1 deletion .github/workflows/nightly-dev-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ jobs:
strategy:
matrix:
python-version:
- "3.7"
- "3.8"
- "3.9"
- "3.10"
- "3.11"
- "3.12"
fail-fast: false
steps:
- uses: actions/checkout@v3
Expand Down
46 changes: 0 additions & 46 deletions .github/workflows/template-sync.yml

This file was deleted.

9 changes: 8 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
name: Tests

on: [pull_request]
on:
pull_request:

push:
branches:
- main

jobs:
run-tests:
Expand All @@ -12,6 +17,8 @@ jobs:
- "3.8"
- "3.9"
- "3.10"
- "3.11"
- "3.12"
fail-fast: false
steps:
- uses: actions/checkout@v3
Expand Down
31 changes: 5 additions & 26 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,29 +1,8 @@
repos:
- repo: https://github.com/pycqa/isort
rev: 5.12.0
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.2.1"
hooks:
- id: isort
- id: ruff
language_version: python3
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
language_version: python3
- repo: https://github.com/pycqa/flake8
rev: 4.0.1
hooks:
- id: flake8
- repo: https://github.com/econchick/interrogate
rev: 1.5.0
hooks:
- id: interrogate
args: [-vv]
pass_filenames: false
- repo: https://github.com/fsouza/autoflake8
rev: v0.3.2
hooks:
- id: autoflake8
language_version: python3
args: [
'--in-place',
]
args: [--fix, --exit-non-zero-on-fix, --show-fixes]
- id: ruff-format
114 changes: 0 additions & 114 deletions MAINTAINERS.md

This file was deleted.

1 change: 0 additions & 1 deletion docs/gen_examples_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def get_code_examples(obj: Union[ModuleType, Callable]) -> Set[str]:

code_examples_grouping = defaultdict(set)
for _, module_name, ispkg in iter_modules(prefect_dask.__path__):

module_nesting = f"{COLLECTION_SLUG}.{module_name}"
module_obj = load_module(module_nesting)

Expand Down
6 changes: 3 additions & 3 deletions prefect_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from contextlib import asynccontextmanager, contextmanager
from datetime import timedelta
from typing import Any, Dict, Optional, Union
from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union

from distributed import Client, get_client
from prefect.context import FlowRunContext, TaskRunContext
Expand Down Expand Up @@ -51,7 +51,7 @@ def _generate_client_kwargs(
def get_dask_client(
timeout: Optional[Union[int, float, str, timedelta]] = None,
**client_kwargs: Dict[str, Any],
) -> Client:
) -> Generator[Client, None, None]:
"""
Yields a temporary synchronous dask client; this is useful
for parallelizing operations on dask collections,
Expand Down Expand Up @@ -108,7 +108,7 @@ def dask_flow():
async def get_async_dask_client(
timeout: Optional[Union[int, float, str, timedelta]] = None,
**client_kwargs: Dict[str, Any],
) -> Client:
) -> AsyncGenerator[Client, None]:
"""
Yields a temporary asynchronous dask client; this is useful
for parallelizing operations on dask collections,
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mkdocs-material
mkdocstrings[python]
isort
pre-commit
pytest-asyncio
pytest-asyncio >= 0.18.2, != 0.22.0, < 0.23.0
mock; python_version < '3.8'
mkdocs-gen-files
interrogate
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
anyio >= 3.7.1, < 4.0.0
prefect>=2.13.5
distributed==2022.2.0; python_version < '3.8'
distributed>=2022.5.0,!=2023.3.2,!=2023.3.2.1,!=2023.4.*,!=2023.5.*; python_version >= '3.8' # don't allow versions from 2023.3.2 to 2023.5 (inclusive) due to issue with get_client starting in 2023.3.2 (fixed in 2023.6.0) - https://github.com/dask/distributed/issues/7763
33 changes: 33 additions & 0 deletions tests/test_task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import sys
from functools import partial
from typing import List
from uuid import uuid4

import cloudpickle
Expand Down Expand Up @@ -273,3 +274,35 @@ async def adapt(self, **kwargs):
)
async with task_runner.start():
assert task_runner._cluster._adapt_called

class TestInputArguments:
async def test_dataclasses_can_be_passed_to_task_runners(self, task_runner):
"""
this is a regression test for https://github.com/PrefectHQ/prefect/issues/6905
"""
from dataclasses import dataclass

@dataclass
class Foo:
value: int

@task
def get_dataclass_values(n: int):
return [Foo(value=i) for i in range(n)]

@task
def print_foo(x: Foo) -> Foo:
print(x)
return x

@flow(task_runner=task_runner)
def test_dask_flow(n: int = 3) -> List[Foo]:
foos = get_dataclass_values(n)
future = print_foo.submit(foos[0])
futures = print_foo.map(foos)

return [fut.result() for fut in futures + [future]]

results = test_dask_flow()

assert results == [Foo(value=i) for i in range(3)] + [Foo(value=0)]
22 changes: 18 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sys

import dask
import pytest
from distributed import Client
Expand Down Expand Up @@ -79,8 +81,14 @@ def test_task():
def test_flow():
test_task.submit()

with pytest.raises(AttributeError, match="__enter__"):
test_flow()
if sys.version_info < (3, 11):
with pytest.raises(AttributeError, match="__enter__"):
test_flow()
else:
with pytest.raises(
TypeError, match="not support the context manager protocol"
):
test_flow()

async def test_from_flow(self):
@flow(task_runner=DaskTaskRunner)
Expand All @@ -99,8 +107,14 @@ def test_flow():
with get_async_dask_client():
pass

with pytest.raises(AttributeError, match="__enter__"):
test_flow()
if sys.version_info < (3, 11):
with pytest.raises(AttributeError, match="__enter__"):
test_flow()
else:
with pytest.raises(
TypeError, match="not support the context manager protocol"
):
test_flow()

async def test_outside_run_context(self):
delayed_num = dask.delayed(42)
Expand Down
Loading
Loading