diff --git a/src/prefect/orion/database/migrations/versions/postgresql/2023_01_31_110543_f98ae6d8e2cc_work_queue_data_migration.py b/src/prefect/orion/database/migrations/versions/postgresql/2023_01_31_110543_f98ae6d8e2cc_work_queue_data_migration.py
new file mode 100644
index 000000000000..625c5adffd17
--- /dev/null
+++ b/src/prefect/orion/database/migrations/versions/postgresql/2023_01_31_110543_f98ae6d8e2cc_work_queue_data_migration.py
@@ -0,0 +1,213 @@
+"""Work queue data migration
+
+Revision ID: f98ae6d8e2cc
+Revises: 0a1250a5aa25
+Create Date: 2023-01-31 11:05:43.356002
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "f98ae6d8e2cc"
+down_revision = "0a1250a5aa25"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # Create temporary indexes for migration
+    op.execute(
+        "CREATE INDEX IF NOT EXISTS ix_flow_run__work_queue_id_work_queue_name ON flow_run (work_queue_id, work_queue_name)"
+    )
+    op.execute(
+        "CREATE INDEX IF NOT EXISTS ix_deployment__work_queue_id_work_queue_name ON deployment (work_queue_id, work_queue_name)"
+    )
+
+    # Create default agent work pool and associate all existing queues with it
+    connection = op.get_bind()
+
+    connection.execute(
+        sa.text(
+            "INSERT INTO work_pool (name, type) VALUES ('default-agent-pool', 'prefect-agent')"
+        )
+    )
+
+    default_pool_id = connection.execute(
+        sa.text("SELECT id FROM work_pool WHERE name = 'default-agent-pool'")
+    ).fetchone()[0]
+
+    default_queue = connection.execute(
+        sa.text("SELECT id FROM work_queue WHERE name = 'default'")
+    ).fetchone()
+
+    if not default_queue:
+        connection.execute(
+            sa.text(
+                f"INSERT INTO work_queue (name, work_pool_id) VALUES ('default', :default_pool_id)"
+            ).params({"default_pool_id": default_pool_id}),
+        )
+
+    connection.execute(
+        sa.text(
+            "UPDATE work_queue SET work_pool_id = :default_pool_id WHERE work_pool_id IS NULL"
+        ).params({"default_pool_id": default_pool_id}),
+    )
+
+    default_queue_id = connection.execute(
+        sa.text(
+            "SELECT id FROM work_queue WHERE name = 'default' and work_pool_id = :default_pool_id"
+        ).params({"default_pool_id": default_pool_id}),
+    ).fetchone()[0]
+
+    connection.execute(
+        sa.text(
+            "UPDATE work_pool SET default_queue_id = :default_queue_id WHERE id = :default_pool_id"
+        ).params(
+            {"default_pool_id": default_pool_id, "default_queue_id": default_queue_id}
+        ),
+    )
+
+    # Set priority on all queues and update flow runs and deployments
+    queue_rows = connection.execute(
+        sa.text(
+            "SELECT id, name FROM work_queue WHERE work_pool_id = :default_pool_id"
+        ).params({"default_pool_id": default_pool_id}),
+    ).fetchall()
+
+    with op.get_context().autocommit_block():
+        for enumeration, row in enumerate(queue_rows):
+            connection.execute(
+                sa.text(
+                    "UPDATE work_queue SET priority = :priority WHERE id = :id"
+                ).params({"priority": enumeration + 1, "id": row[0]}),
+            )
+
+            batch_size = 250
+
+            while True:
+                result = connection.execute(
+                    sa.text(
+                        """
+                        UPDATE flow_run
+                        SET work_queue_id=:id
+                        WHERE flow_run.id in (
+                            SELECT id
+                            FROM flow_run
+                            WHERE flow_run.work_queue_id IS NULL and flow_run.work_queue_name=:name
+                            LIMIT :batch_size
+                        )
+                        """
+                    ).params({"id": row[0], "name": row[1], "batch_size": batch_size}),
+                )
+                if result.rowcount <= batch_size:
+                    break
+
+            while True:
+                result = connection.execute(
+                    sa.text(
+                        """
+                        UPDATE deployment
+                        SET work_queue_id=:id
+                        WHERE deployment.id in (
+                            SELECT id
+                            FROM deployment
+                            WHERE deployment.work_queue_id IS NULL and
deployment.work_queue_name=:name + LIMIT :batch_size + ) + """ + ).params({"id": row[0], "name": row[1], "batch_size": batch_size}), + ) + if result.rowcount <= batch_size: + break + + with op.batch_alter_table("work_queue", schema=None) as batch_op: + batch_op.drop_constraint("uq_work_queue__name") + batch_op.create_unique_constraint( + op.f("uq_work_queue__work_pool_id_name"), ["work_pool_id", "name"] + ) + batch_op.alter_column("work_pool_id", nullable=False) + + op.execute("DROP INDEX IF EXISTS ix_flow_run__work_queue_id_work_queue_name") + op.execute("DROP INDEX IF EXISTS ix_deployment__work_queue_id_work_queue_name") + + +def downgrade(): + # Create temporary indexes for migration + op.execute( + "CREATE INDEX IF NOT EXISTS ix_flow_run__work_queue_id_work_queue_name ON flow_run (work_queue_id, work_queue_name)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_deployment__work_queue_id_work_queue_name ON deployment (work_queue_id, work_queue_name)" + ) + + with op.batch_alter_table("work_queue", schema=None) as batch_op: + batch_op.alter_column("work_pool_id", nullable=True) + + connection = op.get_bind() + + # Delete all non-default queues and pools + default_pool_id_result = connection.execute( + sa.text("SELECT id FROM work_pool WHERE name = 'default-agent-pool'") + ).fetchone() + if default_pool_id_result: + default_pool_id = default_pool_id_result[0] + connection.execute( + sa.text( + "DELETE FROM work_queue WHERE work_pool_id != :default_pool_id" + ).params({"default_pool_id": default_pool_id}) + ) + queue_rows = connection.execute( + sa.text("SELECT id, name FROM work_queue"), + ).fetchall() + + with op.get_context().autocommit_block(): + for row in queue_rows: + batch_size = 250 + + while True: + result = connection.execute( + sa.text( + """ + UPDATE flow_run + SET work_queue_id=NULL + WHERE flow_run.id in ( + SELECT id + FROM flow_run + WHERE flow_run.work_queue_id IS NOT NULL and flow_run.work_queue_id=:id + LIMIT :batch_size + ) + """ + ).params({"id": row[0], "batch_size": batch_size}), + ) + if result.rowcount <= batch_size: + break + + while True: + result = connection.execute( + sa.text( + """ + UPDATE deployment + SET work_queue_id=NULL + WHERE deployment.id in ( + SELECT id + FROM deployment + WHERE deployment.work_queue_id IS NOT NULL and deployment.work_queue_id=:id + LIMIT :batch_size + ) + """ + ).params({"id": row[0], "batch_size": batch_size}), + ) + if result.rowcount <= batch_size: + break + + connection.execute(sa.text("UPDATE work_queue SET work_pool_id = NULL")) + + connection.execute(sa.text("DELETE FROM work_pool")) + + with op.batch_alter_table("work_queue", schema=None) as batch_op: + batch_op.drop_constraint("uq_work_queue__work_pool_id_name") + batch_op.create_unique_constraint("uq_work_queue__name", ["name"]) + + op.execute("DROP INDEX IF EXISTS ix_flow_run__work_queue_id_work_queue_name") + op.execute("DROP INDEX IF EXISTS ix_deployment__work_queue_id_work_queue_name") diff --git a/src/prefect/orion/database/migrations/versions/sqlite/2023_01_31_105442_1678f2fb8b33_work_queue_data_migration.py b/src/prefect/orion/database/migrations/versions/sqlite/2023_01_31_105442_1678f2fb8b33_work_queue_data_migration.py new file mode 100644 index 000000000000..0c57693666e0 --- /dev/null +++ b/src/prefect/orion/database/migrations/versions/sqlite/2023_01_31_105442_1678f2fb8b33_work_queue_data_migration.py @@ -0,0 +1,220 @@ +"""Work queue data migration + +Revision ID: 1678f2fb8b33 +Revises: b9bda9f142f1 +Create Date: 2023-01-31 10:54:42.747849 + +""" +import 
sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "1678f2fb8b33" +down_revision = "b9bda9f142f1" +branch_labels = None +depends_on = None + + +def upgrade(): + op.execute("PRAGMA foreign_keys=OFF") + + # Create temporary indexes for migration + op.execute( + "CREATE INDEX IF NOT EXISTS ix_flow_run__work_queue_id_work_queue_name ON flow_run (work_queue_id, work_queue_name)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_deployment__work_queue_id_work_queue_name ON deployment (work_queue_id, work_queue_name)" + ) + + # Create default agent work pool and associate all existing queues with it + connection = op.get_bind() + + connection.execute( + sa.text( + "INSERT INTO work_pool (name, type) VALUES ('default-agent-pool', 'prefect-agent')" + ) + ) + + default_pool_id = connection.execute( + sa.text("SELECT id FROM work_pool WHERE name = 'default-agent-pool'") + ).fetchone()[0] + + default_queue = connection.execute( + sa.text("SELECT id FROM work_queue WHERE name = 'default'") + ).fetchone() + + if not default_queue: + connection.execute( + sa.text( + f"INSERT INTO work_queue (name, work_pool_id) VALUES ('default', :default_pool_id)" + ).params({"default_pool_id": default_pool_id}), + ) + + connection.execute( + sa.text( + "UPDATE work_queue SET work_pool_id = :default_pool_id WHERE work_pool_id IS NULL" + ).params({"default_pool_id": default_pool_id}), + ) + + default_queue_id = connection.execute( + sa.text( + "SELECT id FROM work_queue WHERE name = 'default' and work_pool_id = :default_pool_id" + ).params({"default_pool_id": default_pool_id}), + ).fetchone()[0] + + connection.execute( + sa.text( + "UPDATE work_pool SET default_queue_id = :default_queue_id WHERE id = :default_pool_id" + ).params( + {"default_pool_id": default_pool_id, "default_queue_id": default_queue_id} + ), + ) + + # Set priority on all queues and update flow runs and deployments + queue_rows = connection.execute( + sa.text( + "SELECT id, name FROM work_queue WHERE work_pool_id = :default_pool_id" + ).params({"default_pool_id": default_pool_id}), + ).fetchall() + + with op.get_context().autocommit_block(): + for enumeration, row in enumerate(queue_rows): + connection.execute( + sa.text( + "UPDATE work_queue SET priority = :priority WHERE id = :id" + ).params({"priority": enumeration + 1, "id": row[0]}), + ) + + batch_size = 250 + + while True: + result = connection.execute( + sa.text( + """ + UPDATE flow_run + SET work_queue_id=:id + WHERE flow_run.id in ( + SELECT id + FROM flow_run + WHERE flow_run.work_queue_id IS NULL and flow_run.work_queue_name=:name + LIMIT :batch_size + ) + """ + ).params({"id": row[0], "name": row[1], "batch_size": batch_size}), + ) + if result.rowcount <= batch_size: + break + + while True: + result = connection.execute( + sa.text( + """ + UPDATE deployment + SET work_queue_id=:id + WHERE deployment.id in ( + SELECT id + FROM deployment + WHERE deployment.work_queue_id IS NULL and deployment.work_queue_name=:name + LIMIT :batch_size + ) + """ + ).params({"id": row[0], "name": row[1], "batch_size": batch_size}), + ) + if result.rowcount <= batch_size: + break + + with op.batch_alter_table("work_queue", schema=None) as batch_op: + batch_op.drop_constraint("uq_work_queue__name") + batch_op.create_unique_constraint( + op.f("uq_work_queue__work_pool_id_name"), ["work_pool_id", "name"] + ) + batch_op.alter_column("work_pool_id", nullable=False) + + op.execute("DROP INDEX IF EXISTS ix_flow_run__work_queue_id_work_queue_name") + op.execute("DROP INDEX 
IF EXISTS ix_deployment__work_queue_id_work_queue_name") + + op.execute("PRAGMA foreign_keys=ON") + + +def downgrade(): + op.execute("PRAGMA foreign_keys=OFF") + + connection = op.get_bind() + # Create temporary indexes for migration + op.execute( + "CREATE INDEX IF NOT EXISTS ix_flow_run__work_queue_id_work_queue_name ON flow_run (work_queue_id, work_queue_name)" + ) + op.execute( + "CREATE INDEX IF NOT EXISTS ix_deployment__work_queue_id_work_queue_name ON deployment (work_queue_id, work_queue_name)" + ) + + with op.batch_alter_table("work_queue", schema=None) as batch_op: + batch_op.alter_column("work_pool_id", nullable=True) + + # Delete all non-default queues and pools + default_pool_id_result = connection.execute( + sa.text("SELECT id FROM work_pool WHERE name = 'default-agent-pool'") + ).fetchone() + if default_pool_id_result: + default_pool_id = default_pool_id_result[0] + connection.execute( + sa.text( + "DELETE FROM work_queue WHERE work_pool_id != :default_pool_id" + ).params({"default_pool_id": default_pool_id}) + ) + queue_rows = connection.execute( + sa.text("SELECT id, name FROM work_queue"), + ).fetchall() + + with op.get_context().autocommit_block(): + for row in queue_rows: + batch_size = 250 + + while True: + result = connection.execute( + sa.text( + """ + UPDATE flow_run + SET work_queue_id=NULL + WHERE flow_run.id in ( + SELECT id + FROM flow_run + WHERE flow_run.work_queue_id IS NOT NULL and flow_run.work_queue_id=:id + LIMIT :batch_size + ) + """ + ).params({"id": row[0], "batch_size": batch_size}), + ) + if result.rowcount <= batch_size: + break + + while True: + result = connection.execute( + sa.text( + """ + UPDATE deployment + SET work_queue_id=NULL + WHERE deployment.id in ( + SELECT id + FROM deployment + WHERE deployment.work_queue_id IS NOT NULL and deployment.work_queue_id=:id + LIMIT :batch_size + ) + """ + ).params({"id": row[0], "batch_size": batch_size}), + ) + if result.rowcount <= batch_size: + break + + connection.execute(sa.text("UPDATE work_queue SET work_pool_id = NULL")) + + connection.execute(sa.text("DELETE FROM work_pool")) + + with op.batch_alter_table("work_queue", schema=None) as batch_op: + batch_op.drop_constraint("uq_work_queue__work_pool_id_name") + batch_op.create_unique_constraint("uq_work_queue__name", ["name"]) + + op.execute("DROP INDEX IF EXISTS ix_flow_run__work_queue_id_work_queue_name") + op.execute("DROP INDEX IF EXISTS ix_deployment__work_queue_id_work_queue_name") + + op.execute("PRAGMA foreign_keys=ON") diff --git a/src/prefect/orion/models/work_queues.py b/src/prefect/orion/models/work_queues.py index 5ad2a20b261d..f59e87287aba 100644 --- a/src/prefect/orion/models/work_queues.py +++ b/src/prefect/orion/models/work_queues.py @@ -47,6 +47,14 @@ async def create_work_queue( ) if default_agent_work_pool: data["work_pool_id"] = default_agent_work_pool.id + else: + default_agent_work_pool = await models.workers.create_work_pool( + session=session, + work_pool=schemas.actions.WorkPoolCreate( + name=DEFAULT_AGENT_WORK_POOL_NAME, type="prefect-agent" + ), + ) + data["work_pool_id"] = default_agent_work_pool.id if data.get("work_pool_id"): max_priority_query = sa.select( @@ -340,11 +348,10 @@ async def _ensure_work_queue_exists( ) else: if name != "default": - work_queue = await models.work_queues.create_work_queue( + work_queue = await models.workers.create_work_queue( session=session, - work_queue=schemas.actions.WorkQueueCreate( - name=name, priority=1, work_pool_id=default_pool.id - ), + work_pool_id=default_pool.id, + 
work_queue=schemas.actions.WorkQueueCreate(name=name, priority=1), ) else: work_queue = await models.work_queues.read_work_queue( diff --git a/tests/cli/deployment/test_deployment_build.py b/tests/cli/deployment/test_deployment_build.py index 7bb0bb0145ab..10daed1ae116 100644 --- a/tests/cli/deployment/test_deployment_build.py +++ b/tests/cli/deployment/test_deployment_build.py @@ -7,6 +7,8 @@ import pendulum import pytest +import prefect.orion.models as models +import prefect.orion.schemas as schemas from prefect import flow from prefect._internal.compatibility.experimental import ExperimentalFeature from prefect.deployments import Deployment @@ -132,6 +134,24 @@ def mock_build_from_flow(monkeypatch): return mock_build_from_flow +@pytest.fixture(autouse=True) +async def ensure_default_agent_pool_exists(session): + # The default agent work pool is created by a migration, but is cleared on + # consecutive test runs. This fixture ensures that the default agent work + # pool exists before each test. + default_work_pool = await models.workers.read_work_pool_by_name( + session=session, work_pool_name=models.workers.DEFAULT_AGENT_WORK_POOL_NAME + ) + if default_work_pool is None: + await models.workers.create_work_pool( + session=session, + work_pool=schemas.actions.WorkPoolCreate( + name=models.workers.DEFAULT_AGENT_WORK_POOL_NAME, type="prefect-agent" + ), + ) + await session.commit() + + class TestSchedules: def test_passing_cron_schedules_to_build(self, patch_import, tmp_path): invoke_and_assert( diff --git a/tests/experimental/workers/test_base_worker.py b/tests/experimental/workers/test_base_worker.py index 7cc7dbd6eb49..0797df039172 100644 --- a/tests/experimental/workers/test_base_worker.py +++ b/tests/experimental/workers/test_base_worker.py @@ -6,6 +6,7 @@ import pytest from pydantic import Field +import prefect.orion.schemas as schemas from prefect.client.orion import OrionClient, get_client from prefect.deployments import Deployment from prefect.exceptions import ObjectNotFound @@ -16,7 +17,6 @@ ) from prefect.flows import flow from prefect.orion import models -from prefect.orion.schemas.core import WorkPool from prefect.settings import ( PREFECT_EXPERIMENTAL_ENABLE_WORK_POOLS, PREFECT_EXPERIMENTAL_ENABLE_WORKERS, @@ -54,6 +54,24 @@ def auto_enable_work_pools(enable_work_pools): assert PREFECT_EXPERIMENTAL_ENABLE_WORK_POOLS +@pytest.fixture(autouse=True) +async def ensure_default_agent_pool_exists(session): + # The default agent work pool is created by a migration, but is cleared on + # consecutive test runs. This fixture ensures that the default agent work + # pool exists before each test. 
+ default_work_pool = await models.workers.read_work_pool_by_name( + session=session, work_pool_name=models.workers.DEFAULT_AGENT_WORK_POOL_NAME + ) + if default_work_pool is None: + await models.workers.create_work_pool( + session=session, + work_pool=schemas.actions.WorkPoolCreate( + name=models.workers.DEFAULT_AGENT_WORK_POOL_NAME, type="prefect-agent" + ), + ) + await session.commit() + + async def test_worker_creates_workflows_directory_during_setup(tmp_path: Path): await WorkerTestImpl( name="test", @@ -421,7 +439,7 @@ class WorkerJobConfig(BaseJobConfiguration): response = await client.post( "/experimental/work_pools/", json=dict(name=pool_name, type="test-type") ) - result = pydantic.parse_obj_as(WorkPool, response.json()) + result = pydantic.parse_obj_as(schemas.core.WorkPool, response.json()) model = await models.workers.read_work_pool(session=session, work_pool_id=result.id) assert model.name == pool_name @@ -475,7 +493,7 @@ class WorkerVariables(BaseVariables): response = await client.post( "/experimental/work_pools/", json=dict(name=pool_name, type="test-type") ) - result = pydantic.parse_obj_as(WorkPool, response.json()) + result = pydantic.parse_obj_as(schemas.core.WorkPool, response.json()) model = await models.workers.read_work_pool(session=session, work_pool_id=result.id) assert model.name == pool_name diff --git a/tests/orion/api/test_deployments.py b/tests/orion/api/test_deployments.py index e65f156051c4..fccf6c7630f1 100644 --- a/tests/orion/api/test_deployments.py +++ b/tests/orion/api/test_deployments.py @@ -1276,12 +1276,28 @@ async def test_well_formed_response( connection_url = PREFECT_ORION_DATABASE_CONNECTION_URL.value() dialect = get_dialect(connection_url) - assert len(response.json()) == 2 - - q1, q2 = response.json() - assert {q1["name"], q2["name"]} == { - "First", - "Second", - } - assert set(q1["filter"]["tags"] + q2["filter"]["tags"]) == {"a", "b"} - assert q1["filter"]["deployment_ids"] == q2["filter"]["deployment_ids"] == None + if dialect.name == "postgresql": + + assert len(response.json()) == 2 + + q1, q2 = response.json() + assert {q1["name"], q2["name"]} == {"First", "Second"} + assert set(q1["filter"]["tags"] + q2["filter"]["tags"]) == {"a", "b"} + assert ( + q1["filter"]["deployment_ids"] == q2["filter"]["deployment_ids"] == None + ) + + else: + # sqlite picks up the default queue because it has no filter + assert len(response.json()) == 3 + + q1, q2, q3 = response.json() + assert {q1["name"], q2["name"], q3["name"]} == { + "First", + "Second", + "default", + } + assert set(q2["filter"]["tags"] + q3["filter"]["tags"]) == {"a", "b"} + assert ( + q2["filter"]["deployment_ids"] == q3["filter"]["deployment_ids"] == None + ) diff --git a/tests/orion/api/test_work_queues.py b/tests/orion/api/test_work_queues.py index 1a27d44bbb47..865e647f0c3e 100644 --- a/tests/orion/api/test_work_queues.py +++ b/tests/orion/api/test_work_queues.py @@ -175,7 +175,7 @@ async def test_read_work_queues(self, work_queues, client): response = await client.post("/work_queues/filter") assert response.status_code == status.HTTP_200_OK # includes default work queue - assert len(response.json()) == 3 + assert len(response.json()) == 4 async def test_read_work_queues_applies_limit(self, work_queues, client): response = await client.post("/work_queues/filter", json=dict(limit=1)) @@ -185,10 +185,11 @@ async def test_read_work_queues_applies_limit(self, work_queues, client): async def test_read_work_queues_offset(self, work_queues, client, session): response = await 
client.post("/work_queues/filter", json=dict(offset=1)) assert response.status_code == status.HTTP_200_OK - assert len(response.json()) == 2 + assert len(response.json()) == 3 # ordered by name by default - assert response.json()[0]["name"] == "wq-1 Y" - assert response.json()[1]["name"] == "wq-2 Y" + assert response.json()[0]["name"] == "wq-1 X" + assert response.json()[1]["name"] == "wq-1 Y" + assert response.json()[2]["name"] == "wq-2 Y" async def test_read_work_queues_by_name(self, work_queues, client, session): response = await client.post( diff --git a/tests/orion/api/test_workers.py b/tests/orion/api/test_workers.py index 354c1747fced..403caeed2042 100644 --- a/tests/orion/api/test_workers.py +++ b/tests/orion/api/test_workers.py @@ -377,9 +377,6 @@ async def test_read_invalid_config(self, client): assert response.status_code == status.HTTP_404_NOT_FOUND -@pytest.mark.skip( - reason="Need unique constraint for work_queue on work_pool_id and name" -) class TestReadWorkPools: @pytest.fixture(autouse=True) async def create_work_pools(self, client): @@ -533,9 +530,6 @@ async def test_heartbeat_worker_limit(self, client, work_pool): assert workers_response.json()[0]["name"] == "another-worker" -@pytest.mark.skip( - reason="Need unique constraint for work_queue on work_pool_id and name" -) class TestGetScheduledRuns: @pytest.fixture(autouse=True) async def setup(self, session, flow): diff --git a/tests/orion/database/test_migrations.py b/tests/orion/database/test_migrations.py index 25a6eda08dc1..3fa9e60130cb 100644 --- a/tests/orion/database/test_migrations.py +++ b/tests/orion/database/test_migrations.py @@ -258,3 +258,181 @@ async def test_adding_work_pool_tables_does_not_remove_fks(db, flow): finally: await run_sync_in_worker_thread(alembic_upgrade) + + +async def test_adding_default_agent_pool_with_existing_default_queue_migration( + db, flow +): + connection_url = PREFECT_ORION_DATABASE_CONNECTION_URL.value() + dialect = get_dialect(connection_url) + + # get the proper migration revisions + if dialect.name == "postgresql": + revisions = ("0a1250a5aa25", "f98ae6d8e2cc") + else: + revisions = ("b9bda9f142f1", "1678f2fb8b33") + + try: + await run_sync_in_worker_thread(alembic_downgrade, revision=revisions[0]) + + session = await db.session() + async with session: + # clear the work queue table + await session.execute(sa.text("DELETE FROM work_queue;")) + await session.commit() + + # insert some work queues into the database + await session.execute( + sa.text("INSERT INTO work_queue (name) values ('default');") + ) + await session.execute( + sa.text("INSERT INTO work_queue (name) values ('queue-1');") + ) + await session.execute( + sa.text("INSERT INTO work_queue (name) values ('queue-2');") + ) + await session.commit() + + # Insert a flow run and deployment to check if they are correctly assigned a work queue ID + flow_run_id = uuid4() + await session.execute( + sa.text( + f"INSERT INTO flow_run (id, name, flow_id, work_queue_name) values ('{flow_run_id}', 'foo', '{flow.id}', 'queue-1');" + ) + ) + await session.execute( + sa.text( + f"INSERT INTO deployment (name, flow_id, work_queue_name) values ('my-deployment', '{flow.id}', 'queue-1');" + ) + ) + await session.commit() + + async with session: + # Confirm the work queues are present + pre_work_queue_ids = ( + await session.execute(sa.text("SELECT id FROM work_queue;")) + ).fetchall() + + assert len(pre_work_queue_ids) == 3 + + # run the migration + await run_sync_in_worker_thread(alembic_upgrade, revision=revisions[1]) + + session = 
await db.session() + async with session: + # Check that work queues are assigned to the default agent pool + default_pool_id = ( + await session.execute( + sa.text( + "SELECT id FROM work_pool WHERE name = 'default-agent-pool';" + ) + ) + ).scalar() + + work_queue_ids = ( + await session.execute( + sa.text( + f"SELECT id FROM work_queue WHERE work_pool_id = '{default_pool_id}';" + ) + ) + ).fetchall() + + assert len(work_queue_ids) == 3 + assert set(work_queue_ids) == set(pre_work_queue_ids) + + # Check that the flow run and deployment are assigned to the correct work queue + queue_1 = ( + await session.execute( + sa.text("SELECT id FROM work_queue WHERE name = 'queue-1';") + ) + ).fetchone() + flow_run = ( + await session.execute( + sa.text( + f"SELECT work_queue_id FROM flow_run WHERE id = '{flow_run_id}';" + ) + ) + ).fetchone() + deployment = ( + await session.execute( + sa.text( + f"SELECT work_queue_id FROM deployment WHERE name = 'my-deployment';" + ) + ) + ).fetchone() + + assert queue_1[0] == flow_run[0] + assert queue_1[0] == deployment[0] + + finally: + await run_sync_in_worker_thread(alembic_upgrade) + + +async def test_adding_default_agent_pool_without_existing_default_queue_migration(db): + connection_url = PREFECT_ORION_DATABASE_CONNECTION_URL.value() + dialect = get_dialect(connection_url) + + # get the proper migration revisions + if dialect.name == "postgresql": + revisions = ("0a1250a5aa25", "f98ae6d8e2cc") + else: + revisions = ("b9bda9f142f1", "1678f2fb8b33") + + try: + await run_sync_in_worker_thread(alembic_downgrade, revision=revisions[0]) + + session = await db.session() + async with session: + # clear the work queue table + await session.execute(sa.text("DELETE FROM work_queue;")) + await session.commit() + + # insert some work queues into the database + await session.execute( + sa.text("INSERT INTO work_queue (name) values ('queue-1');") + ) + await session.execute( + sa.text("INSERT INTO work_queue (name) values ('queue-2');") + ) + await session.execute( + sa.text("INSERT INTO work_queue (name) values ('queue-3');") + ) + await session.commit() + + async with session: + # Confirm the work queues are present + pre_work_queue_names = ( + await session.execute(sa.text("SELECT name FROM work_queue;")) + ).fetchall() + + assert len(pre_work_queue_names) == 3 + + # run the migration + await run_sync_in_worker_thread(alembic_upgrade, revision=revisions[1]) + + session = await db.session() + async with session: + # Check that work queues are assigned to the default agent pool + default_pool_id = ( + await session.execute( + sa.text( + "SELECT id FROM work_pool WHERE name = 'default-agent-pool';" + ) + ) + ).scalar() + + work_queue_names = ( + await session.execute( + sa.text( + f"SELECT name FROM work_queue WHERE work_pool_id = '{default_pool_id}';" + ) + ) + ).fetchall() + + assert len(work_queue_names) == 4 + assert set(work_queue_names) == set(pre_work_queue_names).union( + [("default",)] + ) + + finally: + await run_sync_in_worker_thread(alembic_upgrade) diff --git a/tests/orion/database/test_queries.py b/tests/orion/database/test_queries.py index 49993e15c4cc..8a1030d470b4 100644 --- a/tests/orion/database/test_queries.py +++ b/tests/orion/database/test_queries.py @@ -224,9 +224,6 @@ async def test_query_skips_locked(self, db): assert len(result2) == 0 -@pytest.mark.skip( - reason="Need unique constraint for work_queue on work_pool_id and name" -) class TestGetRunsFromWorkQueueQuery: @pytest.fixture(autouse=True) async def setup(self, session, flow): diff --git 
a/tests/orion/models/deprecated/test_work_queues.py b/tests/orion/models/deprecated/test_work_queues.py index f4ae12a306ff..a734e7b07919 100644 --- a/tests/orion/models/deprecated/test_work_queues.py +++ b/tests/orion/models/deprecated/test_work_queues.py @@ -623,7 +623,14 @@ async def test_no_tag_picks_up_only_number_of_expected_queues( session=session, deployment_id=match_id ) - assert len(actual_queues) == 3 + connection_url = PREFECT_ORION_DATABASE_CONNECTION_URL.value() + dialect = get_dialect(connection_url) + + if dialect.name == "postgresql": + assert len(actual_queues) == 3 + else: + # sqlite picks up the default queue because it has no filter + assert len(actual_queues) == 4 # ONE TAG DEPLOYMENTS with no-tag queues async def test_one_tag_picks_up_no_filter_q(self, session, flow, flow_function): @@ -679,7 +686,11 @@ async def test_one_tag_picks_up_only_number_of_expected_queues( connection_url = PREFECT_ORION_DATABASE_CONNECTION_URL.value() dialect = get_dialect(connection_url) - assert len(actual_queues) == 6 + if dialect.name == "postgresql": + assert len(actual_queues) == 6 + else: + # sqlite picks up the default queue because it has no filter + assert len(actual_queues) == 7 # TWO TAG DEPLOYMENTS with no-tag queues async def test_two_tag_picks_up_no_filter_q(self, session, flow, flow_function): @@ -755,4 +766,8 @@ async def test_two_tag_picks_up_only_number_of_expected_queues( connection_url = PREFECT_ORION_DATABASE_CONNECTION_URL.value() dialect = get_dialect(connection_url) - assert len(actual_queues) == 9 + if dialect.name == "postgresql": + assert len(actual_queues) == 9 + else: + # sqlite picks up the default queue because it has no filter + assert len(actual_queues) == 10 diff --git a/tests/orion/models/test_work_queues.py b/tests/orion/models/test_work_queues.py index 9b7e1fc21900..438c73013f6c 100644 --- a/tests/orion/models/test_work_queues.py +++ b/tests/orion/models/test_work_queues.py @@ -44,19 +44,6 @@ async def test_create_work_queue_throws_exception_on_name_conflict( ), ) - async def test_create_work_queue_throws_exception_on_name_conflict( - self, - session, - work_queue, - ): - with pytest.raises(IntegrityError): - await models.work_queues.create_work_queue( - session=session, - work_queue=schemas.actions.WorkQueueCreate( - name=work_queue.name, - ), - ) - class TestReadWorkQueue: async def test_read_work_queue_by_id(self, session, work_queue): @@ -121,11 +108,11 @@ async def work_queues(self, session): async def test_read_work_queue(self, work_queues, session): read_work_queue = await models.work_queues.read_work_queues(session=session) - assert len(read_work_queue) == len(work_queues) + assert len(read_work_queue) == len(work_queues) + 1 # +1 for default queue async def test_read_work_queue_applies_limit(self, work_queues, session): read_work_queue = await models.work_queues.read_work_queues( - session=session, limit=1 + session=session, limit=1, offset=1 ) assert {queue.id for queue in read_work_queue} == {work_queues[0].id} @@ -134,6 +121,7 @@ async def test_read_work_queue_applies_offset(self, work_queues, session): session=session, offset=1 ) assert {queue.id for queue in read_work_queue} == { + work_queues[0].id, work_queues[1].id, work_queues[2].id, work_queues[3].id, diff --git a/tests/orion/models/test_workers.py b/tests/orion/models/test_workers.py index 00acf0ba00c6..7d9a8adcbe6e 100644 --- a/tests/orion/models/test_workers.py +++ b/tests/orion/models/test_workers.py @@ -633,9 +633,6 @@ async def test_multiple_worker_heartbeats(self, session, 
work_pool): assert processes[2].name == "X" -@pytest.mark.skip( - reason="Need unique constraint for work_queue on work_pool_id and name" -) class TestGetScheduledRuns: @pytest.fixture(autouse=True) async def setup(self, session, flow): diff --git a/tests/test_deployments.py b/tests/test_deployments.py index 2cdcee8e81dc..bd42545f9e8e 100644 --- a/tests/test_deployments.py +++ b/tests/test_deployments.py @@ -8,6 +8,8 @@ from httpx import Response from pydantic.error_wrappers import ValidationError +import prefect.orion.models as models +import prefect.orion.schemas as schemas from prefect import flow, task from prefect.blocks.core import Block from prefect.blocks.fields import SecretDict @@ -22,6 +24,24 @@ from prefect.utilities.slugify import slugify +@pytest.fixture(autouse=True) +async def ensure_default_agent_pool_exists(session): + # The default agent work pool is created by a migration, but is cleared on + # consecutive test runs. This fixture ensures that the default agent work + # pool exists before each test. + default_work_pool = await models.workers.read_work_pool_by_name( + session=session, work_pool_name=models.workers.DEFAULT_AGENT_WORK_POOL_NAME + ) + if default_work_pool is None: + await models.workers.create_work_pool( + session=session, + work_pool=schemas.actions.WorkPoolCreate( + name=models.workers.DEFAULT_AGENT_WORK_POOL_NAME, type="prefect-agent" + ), + ) + await session.commit() + + class TestDeploymentBasicInterface: async def test_that_name_is_required(self): with pytest.raises(ValidationError, match="field required"):
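Both new migrations backfill work_queue_id on flow_run and deployment in batches: rows are selected by work_queue_name with a LIMIT, updated, and the loop repeats, with temporary (work_queue_id, work_queue_name) indexes created up front so each lookup stays cheap. Below is a minimal, standalone sketch of that batching pattern, not code taken from the migration itself; the helper name is hypothetical, and the stop condition (break once a batch comes back smaller than batch_size) is one reasonable way to drain the backlog.

import sqlalchemy as sa


def backfill_flow_run_queue_ids(connection, queue_id, queue_name, batch_size=250):
    # Illustrative helper, not part of the migration: attach work_queue_id to
    # flow runs that still reference their queue only by name, one batch at a
    # time so no single UPDATE touches the whole table.
    while True:
        result = connection.execute(
            sa.text(
                """
                UPDATE flow_run
                SET work_queue_id = :id
                WHERE flow_run.id IN (
                    SELECT id
                    FROM flow_run
                    WHERE flow_run.work_queue_id IS NULL
                        AND flow_run.work_queue_name = :name
                    LIMIT :batch_size
                )
                """
            ).params({"id": queue_id, "name": queue_name, "batch_size": batch_size}),
        )
        # A short batch means there is nothing left to migrate for this queue.
        if result.rowcount < batch_size:
            break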
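The swap from the global uq_work_queue__name constraint to uq_work_queue__work_pool_id_name is what lets the previously skipped TestReadWorkPools, TestGetScheduledRuns, and TestGetRunsFromWorkQueueQuery suites run again: queue names now only need to be unique within a single work pool. A hedged test sketch of that behavior follows; the test name is illustrative, and it assumes the session and work_pool fixtures plus the model helpers already used elsewhere in this diff.

import pytest
from sqlalchemy.exc import IntegrityError

import prefect.orion.models as models
import prefect.orion.schemas as schemas


async def test_work_queue_names_are_unique_per_pool(session, work_pool):
    # The same queue name can now be reused in a different pool...
    other_pool = await models.workers.create_work_pool(
        session=session,
        work_pool=schemas.actions.WorkPoolCreate(
            name="another-pool", type="prefect-agent"
        ),
    )
    await models.workers.create_work_queue(
        session=session,
        work_pool_id=work_pool.id,
        work_queue=schemas.actions.WorkQueueCreate(name="backfill"),
    )
    await models.workers.create_work_queue(
        session=session,
        work_pool_id=other_pool.id,
        work_queue=schemas.actions.WorkQueueCreate(name="backfill"),
    )

    # ...but creating it twice in the same pool still violates the new
    # (work_pool_id, name) unique constraint.
    with pytest.raises(IntegrityError):
        await models.workers.create_work_queue(
            session=session,
            work_pool_id=work_pool.id,
            work_queue=schemas.actions.WorkQueueCreate(name="backfill"),
        )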
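The ensure_default_agent_pool_exists fixture is added verbatim to tests/cli/deployment/test_deployment_build.py, tests/experimental/workers/test_base_worker.py, and tests/test_deployments.py. A shared version that could live in a common conftest.py is sketched below; the conftest placement is an assumption, and the body mirrors the copies in this diff. Note that autouse=True at conftest level would apply the fixture to every test in scope, which may or may not be desirable.

import pytest

import prefect.orion.models as models
import prefect.orion.schemas as schemas


@pytest.fixture(autouse=True)
async def ensure_default_agent_pool_exists(session):
    # The default agent work pool is created by a migration, but is cleared on
    # consecutive test runs; recreate it when missing so every test starts
    # from the same baseline.
    default_work_pool = await models.workers.read_work_pool_by_name(
        session=session, work_pool_name=models.workers.DEFAULT_AGENT_WORK_POOL_NAME
    )
    if default_work_pool is None:
        await models.workers.create_work_pool(
            session=session,
            work_pool=schemas.actions.WorkPoolCreate(
                name=models.workers.DEFAULT_AGENT_WORK_POOL_NAME, type="prefect-agent"
            ),
        )
        await session.commit()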