Skip to content

Commit

Permalink
Merge branch 'network-workflow-api'
Browse files Browse the repository at this point in the history
  • Loading branch information
kba committed Nov 27, 2023
2 parents 35fa77f + d14c787 commit 01f3594
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 15 deletions.
4 changes: 2 additions & 2 deletions ocrd/ocrd/task_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def parse(cls, argstr):
set_json_key_value_overrides(parameters, tokens[1:3])
tokens = tokens[3:]
else:
raise Exception("Failed parsing task description '%s' with tokens remaining: '%s'" % (argstr, tokens))
raise ValueError("Failed parsing task description '%s' with tokens remaining: '%s'" % (argstr, tokens))
return cls(executable, input_file_grps, output_file_grps, parameters)

def __init__(self, executable, input_file_grps, output_file_grps, parameters):
Expand Down Expand Up @@ -108,7 +108,7 @@ def validate_tasks(tasks, workspace, page_id=None, overwrite=False):
# TODO disable output_file_grps checks once CLI parameter 'overwrite' is implemented
# XXX Thu Jan 16 20:14:17 CET 2020 still not sufficiently clever.
# if len(prev_output_file_grps) != len(set(prev_output_file_grps)):
# report.add_error("Output file group specified multiple times: %s" %
# report.add_error("Output file group specified multiple times: %s" %
# [grp for grp, count in Counter(prev_output_file_grps).items() if count >= 2])
prev_output_file_grps += task.output_file_grps
if not report.is_valid:
Expand Down
29 changes: 27 additions & 2 deletions ocrd_network/ocrd_network/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from .models import (
DBProcessorJob,
DBWorkflowJob,
DBWorkspace
DBWorkspace,
DBWorkflowScript,
)
from .utils import call_sync

Expand All @@ -31,7 +32,7 @@ async def initiate_database(db_url: str):
client = AsyncIOMotorClient(db_url)
await init_beanie(
database=client.get_default_database(default='ocrd'),
document_models=[DBProcessorJob, DBWorkflowJob, DBWorkspace]
document_models=[DBProcessorJob, DBWorkflowJob, DBWorkspace, DBWorkflowScript]
)


Expand Down Expand Up @@ -199,3 +200,27 @@ async def db_get_processing_jobs(job_ids: List[str]) -> [DBProcessorJob]:
@call_sync
async def sync_db_get_processing_jobs(job_ids: List[str]) -> [DBProcessorJob]:
return await db_get_processing_jobs(job_ids)


async def db_get_workflow_script(workflow_id: str) -> DBWorkflowScript:
    """Fetch the workflow script stored under ``workflow_id``.

    Raises:
        ValueError: if no workflow script with that id exists in the DB.
    """
    db_workflow = await DBWorkflowScript.find_one(DBWorkflowScript.workflow_id == workflow_id)
    if db_workflow is None:
        raise ValueError(f'Workflow-script with id "{workflow_id}" not in the DB.')
    return db_workflow


@call_sync
async def sync_db_get_workflow_script(workflow_id: str) -> DBWorkflowScript:
    # Blocking wrapper: call_sync drives the coroutine to completion so
    # non-async callers can use db_get_workflow_script.
    return await db_get_workflow_script(workflow_id)


async def db_find_first_workflow_script_by_content(content_hash: str) -> DBWorkflowScript:
    """Fetch the first workflow script whose content hash equals ``content_hash``.

    Raises:
        ValueError: if no workflow script with that content hash exists in the DB.
    """
    db_workflow = await DBWorkflowScript.find_one(DBWorkflowScript.content_hash == content_hash)
    if db_workflow is None:
        raise ValueError(f'Workflow-script with content_hash "{content_hash}" not in the DB.')
    return db_workflow


@call_sync
async def sync_db_find_first_workflow_script_by_content(content_hash: str) -> DBWorkflowScript:
    # Blocking wrapper around db_find_first_workflow_script_by_content.
    # Fix: the previous version took a workflow_id and delegated to
    # db_get_workflow_script, so looking a script up by content hash from
    # synchronous code was impossible.
    return await db_find_first_workflow_script_by_content(content_hash)
2 changes: 1 addition & 1 deletion ocrd_network/ocrd_network/deployment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def create_ssh_client(address: str, username: str, password: str = "", keypath:
try:
client.connect(hostname=address, username=username, password=password, key_filename=keypath)
except Exception as error:
raise Exception(f"Error creating SSHClient of host '{address}', reason:") from error
raise Exception(f"Error creating SSHClient of host '{address}', reason: {error}") from error
return client


Expand Down
2 changes: 2 additions & 0 deletions ocrd_network/ocrd_network/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
'DBProcessorJob',
'DBWorkflowJob',
'DBWorkspace',
'DBWorkflowScript',
'PYJobInput',
'PYJobOutput',
'PYOcrdTool',
Expand All @@ -26,3 +27,4 @@
from .messages import PYResultMessage
from .ocrd_tool import PYOcrdTool
from .workspace import DBWorkspace
from .workflow import DBWorkflowScript
9 changes: 9 additions & 0 deletions ocrd_network/ocrd_network/models/workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from beanie import Document


class DBWorkflowScript(Document):
    """ Model to store a workflow-script in the database
    """
    # Server-generated identifier used to look the script up later
    workflow_id: str
    # Full text of the uploaded workflow script
    content: str
    # md5 hex digest of `content`; used to detect duplicate uploads
    content_hash: str
124 changes: 114 additions & 10 deletions ocrd_network/ocrd_network/processing_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
import requests
import httpx
from os import getpid
from typing import Dict, List
from typing import Dict, List, Union
import uvicorn

from fastapi import (
FastAPI,
status,
Request,
HTTPException,
UploadFile
UploadFile,
File,
)
from fastapi.exceptions import RequestValidationError
from fastapi.responses import FileResponse, JSONResponse
from fastapi.responses import FileResponse, JSONResponse, PlainTextResponse

from pika.exceptions import ChannelClosedByBroker
from ocrd.task_sequence import ProcessorTask
Expand All @@ -29,13 +30,16 @@
db_get_workflow_job,
db_get_workspace,
db_update_processing_job,
db_update_workspace
db_update_workspace,
db_get_workflow_script,
db_find_first_workflow_script_by_content
)
from .deployer import Deployer
from .logging import get_processing_server_logging_file_path
from .models import (
DBProcessorJob,
DBWorkflowJob,
DBWorkflowScript,
PYJobInput,
PYJobOutput,
PYResultMessage,
Expand All @@ -61,9 +65,11 @@
download_ocrd_all_tool_json,
generate_created_time,
generate_id,
get_ocrd_workspace_physical_pages
get_ocrd_workspace_physical_pages,
validate_workflow,
)
from urllib.parse import urljoin
from hashlib import md5


class ProcessingServer(FastAPI):
Expand Down Expand Up @@ -195,7 +201,7 @@ def __init__(self, config_path: str, host: str, port: int) -> None:
)

self.router.add_api_route(
path='/workflow',
path='/workflow/run',
endpoint=self.run_workflow,
methods=['POST'],
tags=['workflow', 'processing'],
Expand All @@ -209,14 +215,39 @@ def __init__(self, config_path: str, host: str, port: int) -> None:
)

self.router.add_api_route(
path='/workflow/{workflow_job_id}',
path='/workflow/job/{workflow_job_id}',
endpoint=self.get_workflow_info,
methods=['GET'],
tags=['workflow', 'processing'],
status_code=status.HTTP_200_OK,
summary='Get information about a workflow run',
)

self.router.add_api_route(
path='/workflow',
endpoint=self.upload_workflow,
methods=['POST'],
tags=['workflow'],
status_code=status.HTTP_201_CREATED,
summary='Upload/Register a new workflow script',
)
self.router.add_api_route(
path='/workflow/{workflow_id}',
endpoint=self.replace_workflow,
methods=['PUT'],
tags=['workflow'],
status_code=status.HTTP_200_OK,
summary='Update/Replace a workflow script',
)
self.router.add_api_route(
path='/workflow/{workflow_id}',
endpoint=self.download_workflow,
methods=['GET'],
tags=['workflow'],
status_code=status.HTTP_200_OK,
summary='Download a workflow script',
)

@self.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
exc_str = f'{exc}'.replace('\n', ' ').replace(' ', ' ')
Expand Down Expand Up @@ -706,8 +737,9 @@ async def task_sequence_to_processing_jobs(

async def run_workflow(
self,
workflow: UploadFile,
mets_path: str,
workflow: Union[UploadFile, None] = File(None),
workflow_id: str = None,
agent_type: str = 'worker',
page_id: str = None,
page_wise: bool = False,
Expand All @@ -721,11 +753,23 @@ async def run_workflow(
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
detail=f"Mets file not existing: {mets_path}")

workflow = (await workflow.read()).decode("utf-8")
if not workflow:
if not workflow_id:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Either workflow or workflow_id must be provided")
try:
workflow = await db_get_workflow_script(workflow_id)
except ValueError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
detail=f"Workflow with id '{workflow_id}' not found")
workflow = workflow.content
else:
workflow = (await workflow.read()).decode("utf-8")

try:
tasks_list = workflow.splitlines()
tasks = [ProcessorTask.parse(task_str) for task_str in tasks_list if task_str.strip()]
except BaseException as e:
except ValueError as e:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Error parsing tasks: {e}")

Expand Down Expand Up @@ -817,3 +861,63 @@ async def get_workflow_info(self, workflow_job_id) -> Dict:
"page_id": job.page_id,
})
return res

async def upload_workflow(self, workflow: UploadFile) -> Dict:
    """ Store a script for a workflow in the database

    Returns:
        dict with the generated "workflow_id" of the stored script.
    Raises:
        HTTPException 422: the uploaded script fails validation.
        HTTPException 409: an identical script (same md5) is already stored.
    """
    content = (await workflow.read()).decode("utf-8")
    if not validate_workflow(content):
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                            detail="Provided workflow script is invalid")

    content_hash = md5(content.encode("utf-8")).hexdigest()
    # Duplicate check. The lookup helper raises ValueError when no script
    # with this hash exists — that is the happy path here. Keep the try
    # body minimal so the conflict HTTPException is raised outside it and
    # can never be confused with (or swallowed by) the ValueError handler.
    try:
        existing = await db_find_first_workflow_script_by_content(content_hash)
    except ValueError:
        existing = None
    if existing is not None:
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="The same workflow"
                            f"-script exists with id '{existing.workflow_id}'")

    workflow_id = generate_id()
    db_workflow_script = DBWorkflowScript(
        workflow_id=workflow_id,
        content=content,
        content_hash=content_hash,
    )
    await db_workflow_script.insert()
    return {"workflow_id": workflow_id}

async def replace_workflow(self, workflow_id, workflow: UploadFile) -> str:
    """ Update a workflow script file in the database

    Validates the uploaded script, then overwrites the stored content and
    content hash of the workflow identified by ``workflow_id``.
    """
    script_text = (await workflow.read()).decode("utf-8")
    if not validate_workflow(script_text):
        raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                            detail="Provided workflow script is invalid")
    try:
        db_workflow_script = await db_get_workflow_script(workflow_id)
        db_workflow_script.content = script_text
        db_workflow_script.content_hash = md5(script_text.encode("utf-8")).hexdigest()
    except ValueError as e:
        # Unknown id — log with traceback and translate into a 404 for the client
        self.log.exception(f"Workflow with id '{workflow_id}' not existing, error: {e}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Workflow-script with id '{workflow_id}' not existing"
        )
    await db_workflow_script.save()
    return db_workflow_script.workflow_id

async def download_workflow(self, workflow_id) -> PlainTextResponse:
    """ Load workflow-script from the database

    Returns the script text as a plain-text response, or 404 if the id is unknown.
    """
    try:
        db_workflow_script = await db_get_workflow_script(workflow_id)
    except ValueError as e:
        # Unknown id — log with traceback and translate into a 404 for the client
        self.log.exception(f"Workflow with id '{workflow_id}' not existing, error: {e}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Workflow-script with id '{workflow_id}' not existing"
        )
    return PlainTextResponse(db_workflow_script.content)
18 changes: 18 additions & 0 deletions ocrd_network/ocrd_network/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ocrd import Resolver, Workspace
from ocrd_validators import ProcessingServerConfigValidator
from .rabbitmq_utils import OcrdResultMessage
from ocrd.task_sequence import ProcessorTask


# Based on: https://gist.github.com/phizaz/20c36c6734878c6ec053245a477572ec
Expand Down Expand Up @@ -147,3 +148,20 @@ def stop_mets_server(mets_server_url: str) -> bool:
if response.status_code == 200:
return True
return False


def validate_workflow(workflow: str, logger=None) -> bool:
    """ Check that workflow is not empty and parseable to a list of ProcessorTask

    Args:
        workflow: raw text of a workflow script; non-blank lines must each
            parse as a ProcessorTask.
        logger: optional logger; validation failures are reported at INFO level.
    Returns:
        True if the workflow is non-empty and every non-blank line parses,
        False otherwise.
    """
    if not workflow.strip():
        if logger:
            logger.info("Workflow is invalid (empty string)")
        return False
    try:
        # Parse purely for validation; results are discarded, so use a plain
        # loop instead of building a throwaway list.
        for task_str in workflow.splitlines():
            if task_str.strip():
                ProcessorTask.parse(task_str)
    except ValueError as e:
        if logger:
            logger.info(f"Workflow is invalid, parsing to ProcessorTasks failed: {e}")
        return False
    return True

0 comments on commit 01f3594

Please sign in to comment.