diff --git a/src/api-service/__app__/agent_events/__init__.py b/src/api-service/__app__/agent_events/__init__.py index 4c1ad8a699..7a6f6a4d2b 100644 --- a/src/api-service/__app__/agent_events/__init__.py +++ b/src/api-service/__app__/agent_events/__init__.py @@ -4,213 +4,56 @@ # Licensed under the MIT License. import logging -from typing import Optional, cast -from uuid import UUID import azure.functions as func -from onefuzztypes.enums import ( - ErrorCode, - NodeState, - NodeTaskState, - TaskDebugFlag, - TaskState, -) from onefuzztypes.models import ( Error, - NodeDoneEventData, NodeEvent, NodeEventEnvelope, - NodeSettingUpEventData, NodeStateUpdate, + Result, WorkerEvent, ) from onefuzztypes.responses import BoolResult from ..onefuzzlib.agent_authorization import verify_token -from ..onefuzzlib.pools import Node, NodeTasks -from ..onefuzzlib.request import RequestException, not_ok, ok, parse_request -from ..onefuzzlib.task_event import TaskEvent -from ..onefuzzlib.tasks.main import Task - -ERROR_CONTEXT = "node event" - - -def get_task_checked(task_id: UUID) -> Task: - task = Task.get_by_task_id(task_id) - if isinstance(task, Error): - raise RequestException(task) - return task - - -def get_node_checked(machine_id: UUID) -> Node: - node = Node.get_by_machine_id(machine_id) - if not node: - err = Error(code=ErrorCode.INVALID_NODE, errors=["unable to find node"]) - raise RequestException(err) - return node - - -def on_state_update( - machine_id: UUID, - state_update: NodeStateUpdate, -) -> None: - state = state_update.state - node = get_node_checked(machine_id) - - if state == NodeState.free: - if node.reimage_requested or node.delete_requested: - logging.info("stopping free node with reset flags: %s", node.machine_id) - node.stop() - return - - if node.could_shrink_scaleset(): - logging.info("stopping free node to resize scaleset: %s", node.machine_id) - node.set_halt() - return - - if state == NodeState.init: - if node.delete_requested: - node.stop() - return - node.reimage_requested = False - node.save() - elif node.state not in NodeState.ready_for_reset(): - if node.state != state: - node.state = state - node.save() - - if state == NodeState.setting_up: - # Model-validated. - # - # This field will be required in the future. - # For now, it is optional for back compat. - setting_up_data = cast( - Optional[NodeSettingUpEventData], - state_update.data, - ) - - if setting_up_data: - for task_id in setting_up_data.tasks: - task = get_task_checked(task_id) - - # The task state may be `running` if it has `vm_count` > 1, and - # another node is concurrently executing the task. If so, leave - # the state as-is, to represent the max progress made. - # - # Other states we would want to preserve are excluded by the - # outermost conditional check. - if task.state != TaskState.running: - task.state = TaskState.setting_up - - task.on_start() - task.save() - - # Note: we set the node task state to `setting_up`, even though - # the task itself may be `running`. - node_task = NodeTasks( - machine_id=machine_id, - task_id=task_id, - state=NodeTaskState.setting_up, - ) - node_task.save() - - elif state == NodeState.done: - # if tasks are running on the node when it reports as Done - # those are stopped early - node.mark_tasks_stopped_early() - - # Model-validated. - # - # This field will be required in the future. - # For now, it is optional for back compat. - done_data = cast(Optional[NodeDoneEventData], state_update.data) - if done_data: - # TODO: do something with this done data - if done_data.error: - logging.error( - "node 'done' with error: machine_id:%s, data:%s", - machine_id, - done_data, - ) - else: - logging.debug("No change in Node state") - else: - logging.info("ignoring state updates from the node: %s: %s", machine_id, state) - - -def on_worker_event(machine_id: UUID, event: WorkerEvent) -> None: - if event.running: - task_id = event.running.task_id - elif event.done: - task_id = event.done.task_id - else: - raise NotImplementedError - - task = get_task_checked(task_id) - node = get_node_checked(machine_id) - node_task = NodeTasks( - machine_id=machine_id, task_id=task_id, state=NodeTaskState.running +from ..onefuzzlib.agent_events import on_state_update, on_worker_event +from ..onefuzzlib.request import not_ok, ok, parse_request + + +def process(envelope: NodeEventEnvelope) -> Result[None]: + logging.info( + "node event: machine_id: %s event: %s", + envelope.machine_id, + envelope.event, ) - if event.running: - if task.state not in TaskState.shutting_down(): - task.state = TaskState.running - if node.state not in NodeState.ready_for_reset(): - node.state = NodeState.busy - node.save() - node_task.save() - - # Start the clock for the task if it wasn't started already - # (as happens in 1.0.0 agents) - task.on_start() - elif event.done: - exit_status = event.done.exit_status - if not exit_status.success: - logging.error("task failed. status:%s", exit_status) - task.mark_failed( - Error( - code=ErrorCode.TASK_FAILED, - errors=[ - "task failed. exit_status:%s" % exit_status, - event.done.stdout[-4096:], - event.done.stderr[-4096:], - ], - ) - ) - if task.config.debug and ( - TaskDebugFlag.keep_node_on_failure in task.config.debug - or TaskDebugFlag.keep_node_on_completion in task.config.debug - ): - node.debug_keep_node = True - node.save() - - else: - task.mark_stopping() - if ( - task.config.debug - and TaskDebugFlag.keep_node_on_completion in task.config.debug - ): - node.debug_keep_node = True - node.save() - - node.to_reimage(done=True) - else: - err = Error( - code=ErrorCode.INVALID_REQUEST, - errors=["invalid worker event type"], - ) - raise RequestException(err) + if isinstance(envelope.event, NodeStateUpdate): + return on_state_update(envelope.machine_id, envelope.event) - task.save() + if isinstance(envelope.event, WorkerEvent): + return on_worker_event(envelope.machine_id, envelope.event) - task_event = TaskEvent(task_id=task_id, machine_id=machine_id, event_data=event) - task_event.save() + if isinstance(envelope.event, NodeEvent): + if envelope.event.state_update: + result = on_state_update(envelope.machine_id, envelope.event.state_update) + if result is not None: + return result + + if envelope.event.worker_event: + result = on_worker_event(envelope.machine_id, envelope.event.worker_event) + if result is not None: + return result + + return None + + raise NotImplementedError("invalid node event: %s" % envelope) def post(req: func.HttpRequest) -> func.HttpResponse: envelope = parse_request(NodeEventEnvelope, req) if isinstance(envelope, Error): - return not_ok(envelope, context=ERROR_CONTEXT) + return not_ok(envelope, context="node event") logging.info( "node event: machine_id: %s event: %s", @@ -218,34 +61,15 @@ def post(req: func.HttpRequest) -> func.HttpResponse: envelope.event.json(exclude_none=True), ) - if isinstance(envelope.event, NodeEvent): - event = envelope.event - elif isinstance(envelope.event, NodeStateUpdate): - event = NodeEvent(state_update=envelope.event) - elif isinstance(envelope.event, WorkerEvent): - event = NodeEvent(worker_event=envelope.event) - else: - err = Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid node event"]) - return not_ok(err, context=ERROR_CONTEXT) - - if event.state_update: - on_state_update(envelope.machine_id, event.state_update) - return ok(BoolResult(result=True)) - elif event.worker_event: - on_worker_event(envelope.machine_id, event.worker_event) - return ok(BoolResult(result=True)) - else: - err = Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid node event"]) - return not_ok(err, context=ERROR_CONTEXT) + result = process(envelope) + if isinstance(result, Error): + logging.error( + "unable to process agent event. envelope:%s error:%s", envelope, result + ) + return not_ok(result, context="node event") + + return ok(BoolResult(result=True)) def main(req: func.HttpRequest) -> func.HttpResponse: - try: - if req.method == "POST": - m = post - else: - raise Exception("invalid method") - - return verify_token(req, m) - except RequestException as r: - return not_ok(r.error, context=ERROR_CONTEXT) + return verify_token(req, post) diff --git a/src/api-service/__app__/agent_events/function.json b/src/api-service/__app__/agent_events/function.json index aa41abc8b8..95958727dd 100644 --- a/src/api-service/__app__/agent_events/function.json +++ b/src/api-service/__app__/agent_events/function.json @@ -7,16 +7,14 @@ "direction": "in", "name": "req", "methods": [ - "get", - "post", - "delete" + "post" ], "route": "agents/events" - }, - { + }, + { "type": "http", "direction": "out", "name": "$return" - } + } ] -} +} \ No newline at end of file diff --git a/src/api-service/__app__/onefuzzlib/agent_events.py b/src/api-service/__app__/onefuzzlib/agent_events.py new file mode 100644 index 0000000000..5ac20d6d45 --- /dev/null +++ b/src/api-service/__app__/onefuzzlib/agent_events.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import logging +from typing import Optional, cast +from uuid import UUID + +from onefuzztypes.enums import ( + ErrorCode, + NodeState, + NodeTaskState, + TaskDebugFlag, + TaskState, +) +from onefuzztypes.models import ( + Error, + NodeDoneEventData, + NodeSettingUpEventData, + NodeStateUpdate, + Result, + WorkerDoneEvent, + WorkerEvent, + WorkerRunningEvent, +) + +from ..onefuzzlib.pools import Node, NodeTasks +from ..onefuzzlib.task_event import TaskEvent +from ..onefuzzlib.tasks.main import Task + + +def get_node(machine_id: UUID) -> Result[Node]: + node = Node.get_by_machine_id(machine_id) + if not node: + return Error(code=ErrorCode.INVALID_NODE, errors=["unable to find node"]) + return node + + +def on_state_update( + machine_id: UUID, + state_update: NodeStateUpdate, +) -> Result[None]: + state = state_update.state + node = get_node(machine_id) + if isinstance(node, Error): + return node + + if state == NodeState.free: + if node.reimage_requested or node.delete_requested: + logging.info("stopping free node with reset flags: %s", node.machine_id) + node.stop() + return None + + if node.could_shrink_scaleset(): + logging.info("stopping free node to resize scaleset: %s", node.machine_id) + node.set_halt() + return None + + if state == NodeState.init: + if node.delete_requested: + logging.info("stopping node (init and delete_requested): %s", machine_id) + node.stop() + return None + + # not checking reimage_requested, as nodes only send 'init' state once. If + # they send 'init' with reimage_requested, it's because the node was reimaged + # successfully. + node.reimage_requested = False + node.state = state + node.save() + return None + + logging.info("node state update: %s from:%s to:%s", machine_id, node.state, state) + node.state = state + node.save() + + if state == NodeState.free: + logging.info("node now available for work: %s", machine_id) + elif state == NodeState.setting_up: + # Model-validated. + # + # This field will be required in the future. + # For now, it is optional for back compat. + setting_up_data = cast( + Optional[NodeSettingUpEventData], + state_update.data, + ) + + if setting_up_data: + if not setting_up_data.tasks: + return Error( + code=ErrorCode.INVALID_REQUEST, + errors=["setup without tasks. machine_id: %s", str(machine_id)], + ) + + for task_id in setting_up_data.tasks: + task = Task.get_by_task_id(task_id) + if isinstance(task, Error): + return task + + logging.info( + "node starting task. machine_id: %s job_id: %s task_id: %s", + machine_id, + task.job_id, + task.task_id, + ) + + # The task state may be `running` if it has `vm_count` > 1, and + # another node is concurrently executing the task. If so, leave + # the state as-is, to represent the max progress made. + # + # Other states we would want to preserve are excluded by the + # outermost conditional check. + if task.state not in [TaskState.running, TaskState.setting_up]: + task.state = TaskState.setting_up + task.save() + task.on_start() + + # Note: we set the node task state to `setting_up`, even though + # the task itself may be `running`. + node_task = NodeTasks( + machine_id=machine_id, + task_id=task_id, + state=NodeTaskState.setting_up, + ) + node_task.save() + + elif state == NodeState.done: + # if tasks are running on the node when it reports as Done + # those are stopped early + node.mark_tasks_stopped_early() + + # Model-validated. + # + # This field will be required in the future. + # For now, it is optional for back compat. + done_data = cast(Optional[NodeDoneEventData], state_update.data) + if done_data: + # TODO: do something with this done data + if done_data.error: + logging.error( + "node 'done' with error: machine_id:%s, data:%s", + machine_id, + done_data, + ) + return None + + +def on_worker_event_running( + machine_id: UUID, event: WorkerRunningEvent +) -> Result[None]: + task = Task.get_by_task_id(event.task_id) + if isinstance(task, Error): + return task + + node = get_node(machine_id) + if isinstance(node, Error): + return node + + if node.state not in NodeState.ready_for_reset(): + node.state = NodeState.busy + node.save() + + node_task = NodeTasks( + machine_id=machine_id, task_id=event.task_id, state=NodeTaskState.running + ) + node_task.save() + + if task.state in TaskState.shutting_down(): + logging.info( + "ignoring task start from node. machine_id:%s %s:%s (state: %s)", + machine_id, + task.job_id, + task.task_id, + task.state, + ) + return None + + logging.info( + "task started on node. machine_id:%s %s:%s", + machine_id, + task.job_id, + task.task_id, + ) + task.state = TaskState.running + task.save() + + # Start the clock for the task if it wasn't started already + # (as happens in 1.0.0 agents) + task.on_start() + + return None + + +def on_worker_event_done(machine_id: UUID, event: WorkerDoneEvent) -> Result[None]: + task = Task.get_by_task_id(event.task_id) + if isinstance(task, Error): + return task + + node = get_node(machine_id) + if isinstance(node, Error): + return node + + if event.exit_status.success: + logging.info( + "task done. %s:%s status:%s", task.job_id, task.task_id, event.exit_status + ) + task.mark_stopping() + if ( + task.config.debug + and TaskDebugFlag.keep_node_on_completion in task.config.debug + ): + node.debug_keep_node = True + node.save() + else: + logging.error( + "task failed. %s:%s status:%s", task.job_id, task.task_id, event.exit_status + ) + task.mark_failed( + Error( + code=ErrorCode.TASK_FAILED, + errors=[ + "task failed. exit_status:%s" % event.exit_status, + event.stdout[-4096:], + event.stderr[-4096:], + ], + ) + ) + + if task.config.debug and ( + TaskDebugFlag.keep_node_on_failure in task.config.debug + or TaskDebugFlag.keep_node_on_completion in task.config.debug + ): + node.debug_keep_node = True + node.save() + + node.to_reimage(done=True) + task_event = TaskEvent( + task_id=task.task_id, machine_id=machine_id, event_data=WorkerEvent(done=event) + ) + task_event.save() + return None + + +def on_worker_event(machine_id: UUID, event: WorkerEvent) -> Result[None]: + if event.running: + return on_worker_event_running(machine_id, event.running) + elif event.done: + return on_worker_event_done(machine_id, event.done) + else: + raise NotImplementedError diff --git a/src/pytypes/onefuzztypes/models.py b/src/pytypes/onefuzztypes/models.py index 71431782af..cb8458d40f 100644 --- a/src/pytypes/onefuzztypes/models.py +++ b/src/pytypes/onefuzztypes/models.py @@ -4,7 +4,7 @@ # Licensed under the MIT License. from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union from uuid import UUID, uuid4 from pydantic import BaseModel, Field, root_validator, validator @@ -58,6 +58,10 @@ class Error(BaseModel): errors: List[str] +OkType = TypeVar("OkType") +Result = Union[OkType, Error] + + class FileEntry(BaseModel): container: Container filename: str