Skip to content

Commit

Permalink
feat(diracx-router): include pagination to search
Browse files Browse the repository at this point in the history
  • Loading branch information
aldbr authored and chrisburr committed Jun 5, 2024
1 parent c43d102 commit 73647d6
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 56 deletions.
23 changes: 17 additions & 6 deletions diracx-db/src/diracx/db/sql/jobs/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,14 @@ async def summary(self, group_by, search) -> list[dict[str, str | int]]:

async def search(
self,
parameters: list[str],
parameters: list[str] | None,
search: list[SearchSpec],
sorts: list[SortSpec],
*,
distinct: bool = False,
per_page: int = 100,
page: int | None = None,
) -> list[dict[str, Any]]:
) -> tuple[int, list[dict[Any, Any]]]:
# Find which columns to select
columns = _get_columns(Jobs.__table__, parameters)
stmt = select(*columns)
Expand Down Expand Up @@ -123,12 +123,23 @@ async def search(
if distinct:
stmt = stmt.distinct()

# Calculate total count before applying pagination
total_count_subquery = stmt.alias()
total_count_stmt = select(func.count()).select_from(total_count_subquery)
total = (await self.conn.execute(total_count_stmt)).scalar_one()

# Apply pagination
if page:
raise NotImplementedError("TODO Not yet implemented")
if page is not None:
if page < 1:
raise InvalidQueryError("Page must be a positive integer")
if per_page < 1:
raise InvalidQueryError("Per page must be a positive integer")
stmt = stmt.offset((page - 1) * per_page).limit(per_page)

# Execute the query
return [dict(row._mapping) async for row in (await self.conn.stream(stmt))]
return total, [
dict(row._mapping) async for row in (await self.conn.stream(stmt))
]

async def _insertNewJDL(self, jdl) -> int:
from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL
Expand Down Expand Up @@ -323,7 +334,7 @@ async def rescheduleJob(self, job_id) -> dict[str, Any]:
from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd
from DIRAC.Core.Utilities.ReturnValues import SErrorException

result = await self.search(
_, result = await self.search(
parameters=[
"Status",
"MinorStatus",
Expand Down
2 changes: 1 addition & 1 deletion diracx-db/src/diracx/db/sql/jobs/status_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def set_job_status(
for key, value in status.items():
statusDict[key] = {k: v for k, v in value.dict().items() if v is not None}

res = await job_db.search(
_, res = await job_db.search(
parameters=["Status", "StartExecTime", "EndExecTime"],
search=[
{
Expand Down
117 changes: 96 additions & 21 deletions diracx-db/tests/jobs/test_jobDB.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ async def job_db(tmp_path):
async def test_search_parameters(job_db):
"""Test that we can search specific parameters for jobs in the database."""
async with job_db as job_db:
result = await job_db.search(["JobID"], [], [])
total, result = await job_db.search(["JobID"], [], [])
assert total == 0
assert not result

result = await asyncio.gather(
Expand All @@ -50,26 +51,34 @@ async def test_search_parameters(job_db):

async with job_db as job_db:
# Search a specific parameter: JobID
result = await job_db.search(["JobID"], [], [])
total, result = await job_db.search(["JobID"], [], [])
assert total == 100
assert result
for r in result:
assert r.keys() == {"JobID"}

# Search a specific parameter: Status
result = await job_db.search(["Status"], [], [])
total, result = await job_db.search(["Status"], [], [])
assert total == 100
assert result
for r in result:
assert r.keys() == {"Status"}

# Search for multiple parameters: JobID, Status
result = await job_db.search(["JobID", "Status"], [], [])
total, result = await job_db.search(["JobID", "Status"], [], [])
assert total == 100
assert result
for r in result:
assert r.keys() == {"JobID", "Status"}

# Search for a specific parameter but use distinct: Status
total, result = await job_db.search(["Status"], [], [], distinct=True)
assert total == 1
assert result

# Search for a non-existent parameter: Dummy
with pytest.raises(InvalidQueryError):
result = await job_db.search(["Dummy"], [], [])
total, result = await job_db.search(["Dummy"], [], [])


async def test_search_conditions(job_db):
Expand All @@ -94,7 +103,8 @@ async def test_search_conditions(job_db):
condition = ScalarSearchSpec(
parameter="JobID", operator=ScalarSearchOperator.EQUAL, value=3
)
result = await job_db.search([], [condition], [])
total, result = await job_db.search([], [condition], [])
assert total == 1
assert result
assert len(result) == 1
assert result[0]["JobID"] == 3
Expand All @@ -103,7 +113,8 @@ async def test_search_conditions(job_db):
condition = ScalarSearchSpec(
parameter="JobID", operator=ScalarSearchOperator.LESS_THAN, value=3
)
result = await job_db.search([], [condition], [])
total, result = await job_db.search([], [condition], [])
assert total == 2
assert result
assert len(result) == 2
assert result[0]["JobID"] == 1
Expand All @@ -113,7 +124,8 @@ async def test_search_conditions(job_db):
condition = ScalarSearchSpec(
parameter="JobID", operator=ScalarSearchOperator.NOT_EQUAL, value=3
)
result = await job_db.search([], [condition], [])
total, result = await job_db.search([], [condition], [])
assert total == 99
assert result
assert len(result) == 99
assert all(r["JobID"] != 3 for r in result)
Expand All @@ -122,14 +134,15 @@ async def test_search_conditions(job_db):
condition = ScalarSearchSpec(
parameter="JobID", operator=ScalarSearchOperator.EQUAL, value=5873
)
result = await job_db.search([], [condition], [])
total, result = await job_db.search([], [condition], [])
assert not result

# Search a specific vector condition: JobID in 1,2,3
condition = VectorSearchSpec(
parameter="JobID", operator=VectorSearchOperator.IN, values=[1, 2, 3]
)
result = await job_db.search([], [condition], [])
total, result = await job_db.search([], [condition], [])
assert total == 3
assert result
assert len(result) == 3
assert all(r["JobID"] in [1, 2, 3] for r in result)
Expand All @@ -138,7 +151,8 @@ async def test_search_conditions(job_db):
condition = VectorSearchSpec(
parameter="JobID", operator=VectorSearchOperator.IN, values=[1, 2, 5873]
)
result = await job_db.search([], [condition], [])
total, result = await job_db.search([], [condition], [])
assert total == 2
assert result
assert len(result) == 2
assert all(r["JobID"] in [1, 2] for r in result)
Expand All @@ -150,7 +164,8 @@ async def test_search_conditions(job_db):
condition2 = VectorSearchSpec(
parameter="JobID", operator=VectorSearchOperator.IN, values=[4, 5, 6]
)
result = await job_db.search([], [condition1, condition2], [])
total, result = await job_db.search([], [condition1, condition2], [])
assert total == 1
assert result
assert len(result) == 1
assert result[0]["JobID"] == 5
Expand All @@ -163,12 +178,13 @@ async def test_search_conditions(job_db):
condition2 = VectorSearchSpec(
parameter="JobID", operator=VectorSearchOperator.IN, values=[4, 5, 6]
)
result = await job_db.search([], [condition1, condition2], [])
total, result = await job_db.search([], [condition1, condition2], [])
assert total == 0
assert not result


async def test_search_sorts(job_db):
""""""
"""Test that we can search for jobs in the database and sort the results."""
async with job_db as job_db:
result = await asyncio.gather(
*(
Expand All @@ -187,29 +203,33 @@ async def test_search_sorts(job_db):
async with job_db as job_db:
# Search and sort by JobID in ascending order
sort = SortSpec(parameter="JobID", direction=SortDirection.ASC)
result = await job_db.search([], [], [sort])
total, result = await job_db.search([], [], [sort])
assert total == 100
assert result
for i, r in enumerate(result):
assert r["JobID"] == i + 1

# Search and sort by JobID in descending order
sort = SortSpec(parameter="JobID", direction=SortDirection.DESC)
result = await job_db.search([], [], [sort])
total, result = await job_db.search([], [], [sort])
assert total == 100
assert result
for i, r in enumerate(result):
assert r["JobID"] == 100 - i

# Search and sort by Owner in ascending order
sort = SortSpec(parameter="Owner", direction=SortDirection.ASC)
result = await job_db.search([], [], [sort])
total, result = await job_db.search([], [], [sort])
assert total == 100
assert result
# Assert that owner10 is before owner2 because of the lexicographical order
assert result[2]["Owner"] == "owner10"
assert result[12]["Owner"] == "owner2"

# Search and sort by Owner in descending order
sort = SortSpec(parameter="Owner", direction=SortDirection.DESC)
result = await job_db.search([], [], [sort])
total, result = await job_db.search([], [], [sort])
assert total == 100
assert result
# Assert that owner10 is before owner2 because of the lexicographical order
assert result[97]["Owner"] == "owner10"
Expand All @@ -218,7 +238,8 @@ async def test_search_sorts(job_db):
# Search and sort by OwnerGroup in ascending order and JobID in descending order
sort1 = SortSpec(parameter="OwnerGroup", direction=SortDirection.ASC)
sort2 = SortSpec(parameter="JobID", direction=SortDirection.DESC)
result = await job_db.search([], [], [sort1, sort2])
total, result = await job_db.search([], [], [sort1, sort2])
assert total == 100
assert result
assert result[0]["OwnerGroup"] == "owner_group1"
assert result[0]["JobID"] == 50
Expand All @@ -228,8 +249,62 @@ async def test_search_sorts(job_db):

async def test_search_pagination(job_db):
"""Test that we can search for jobs in the database."""
# TODO: Implement pagination
pass
async with job_db as job_db:
result = await asyncio.gather(
*(
job_db.insert(
f"JDL{i}",
f"owner{i}",
"owner_group1" if i < 50 else "owner_group2",
"New",
"dfdfds",
"lhcb",
)
for i in range(100)
)
)

async with job_db as job_db:
# Search for the first 10 jobs
total, result = await job_db.search([], [], [], per_page=10, page=1)
assert total == 100
assert result
assert len(result) == 10
assert result[0]["JobID"] == 1

# Search for the second 10 jobs
total, result = await job_db.search([], [], [], per_page=10, page=2)
assert total == 100
assert result
assert len(result) == 10
assert result[0]["JobID"] == 11

# Search for the last 10 jobs
total, result = await job_db.search([], [], [], per_page=10, page=10)
assert total == 100
assert result
assert len(result) == 10
assert result[0]["JobID"] == 91

# Search for the second 50 jobs
total, result = await job_db.search([], [], [], per_page=50, page=2)
assert total == 100
assert result
assert len(result) == 50
assert result[0]["JobID"] == 51

# Invalid page number
total, result = await job_db.search([], [], [], per_page=10, page=11)
assert total == 100
assert not result

# Invalid page number
with pytest.raises(InvalidQueryError):
result = await job_db.search([], [], [], per_page=10, page=0)

# Invalid per_page number
with pytest.raises(InvalidQueryError):
result = await job_db.search([], [], [], per_page=0, page=1)


async def test_set_job_command_invalid_job_id(job_db: JobDB):
Expand Down
Loading

0 comments on commit 73647d6

Please sign in to comment.