Skip to content
This repository has been archived by the owner on Nov 1, 2023. It is now read-only.

refactor agent_events handler #261

Merged
merged 22 commits into from
Nov 11, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
322 changes: 145 additions & 177 deletions src/api-service/__app__/agent_events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
# Licensed under the MIT License.

import logging
from typing import Optional, cast
from typing import Optional, cast, Union
from uuid import UUID

import azure.functions as func
from onefuzztypes.enums import (
ErrorCode,
NodeState,
Expand All @@ -18,44 +17,33 @@
from onefuzztypes.models import (
Error,
NodeDoneEventData,
NodeEvent,
NodeEventEnvelope,
NodeSettingUpEventData,
NodeStateUpdate,
WorkerEvent,
WorkerDoneEvent,
WorkerRunningEvent,
)
from onefuzztypes.responses import BoolResult

from ..onefuzzlib.agent_authorization import verify_token
from ..onefuzzlib.pools import Node, NodeTasks
from ..onefuzzlib.request import RequestException, not_ok, ok, parse_request
from ..onefuzzlib.task_event import TaskEvent
from ..onefuzzlib.tasks.main import Task

ERROR_CONTEXT = "node event"


def get_task_checked(task_id: UUID) -> Task:
    """Fetch the task for ``task_id`` or raise if the lookup yields an error.

    The lookup is delegated to ``Task.get_by_task_id``. When that call hands
    back an ``Error`` instead of a task, the error is wrapped in a
    ``RequestException`` and raised; otherwise the looked-up value is
    returned unchanged.

    Raises:
        RequestException: the lookup returned an ``Error``.
    """
    found = Task.get_by_task_id(task_id)
    if not isinstance(found, Error):
        return found
    raise RequestException(found)


def get_node_checked(machine_id: UUID) -> Node:
def get_node(machine_id: UUID) -> Union[Node, Error]:
node = Node.get_by_machine_id(machine_id)
if not node:
err = Error(code=ErrorCode.INVALID_NODE, errors=["unable to find node"])
raise RequestException(err)
return Error(code=ErrorCode.INVALID_NODE, errors=["unable to find node"])
return node


def on_state_update(
machine_id: UUID,
state_update: NodeStateUpdate,
) -> None:
) -> Optional[Error]:
state = state_update.state
node = get_node_checked(machine_id)
node = get_node(machine_id)
if not isinstance(node, Node):
return node

if state == NodeState.free:
if node.reimage_requested or node.delete_requested:
Expand All @@ -70,184 +58,164 @@ def on_state_update(

if state == NodeState.init:
if node.delete_requested:
ranweiler marked this conversation as resolved.
Show resolved Hide resolved
logging.info("stopping node (init and delete_requested): %s", machine_id)
node.stop()
return
node.reimage_requested = False
node.state = state
node.save()
elif node.state not in NodeState.ready_for_reset():
if node.state != state:
node.state = state
node.save()
return

logging.info("node state update: %s from:%s to:%s", machine_id, node.state, state)
node.state = state
node.save()

if state == NodeState.setting_up:
# Model-validated.
#
# This field will be required in the future.
# For now, it is optional for back compat.
setting_up_data = cast(
Optional[NodeSettingUpEventData],
state_update.data,
)

if state == NodeState.setting_up:
# Model-validated.
if setting_up_data:
if not setting_up_data.tasks:
raise Exception("setup without tasks. machine_id: %s", machine_id)

for task_id in setting_up_data.tasks:
task = Task.get_by_task_id(task_id)
if not isinstance(task, Task):
return task

# The task state may be `running` if it has `vm_count` > 1, and
# another node is concurrently executing the task. If so, leave
# the state as-is, to represent the max progress made.
#
# This field will be required in the future.
# For now, it is optional for back compat.
setting_up_data = cast(
Optional[NodeSettingUpEventData],
state_update.data,
# Other states we would want to preserve are excluded by the
# outermost conditional check.
if task.state != TaskState.running:
task.state = TaskState.setting_up

task.on_start()
task.save()

# Note: we set the node task state to `setting_up`, even though
# the task itself may be `running`.
node_task = NodeTasks(
machine_id=machine_id,
task_id=task_id,
state=NodeTaskState.setting_up,
)
node_task.save()

elif state == NodeState.done:
# if tasks are running on the node when it reports as Done
# those are stopped early
node.mark_tasks_stopped_early()

# Model-validated.
#
# This field will be required in the future.
# For now, it is optional for back compat.
done_data = cast(Optional[NodeDoneEventData], state_update.data)
if done_data:
# TODO: do something with this done data
if done_data.error:
logging.error(
"node 'done' with error: machine_id:%s, data:%s",
machine_id,
done_data,
)

if setting_up_data:
for task_id in setting_up_data.tasks:
task = get_task_checked(task_id)

# The task state may be `running` if it has `vm_count` > 1, and
# another node is concurrently executing the task. If so, leave
# the state as-is, to represent the max progress made.
#
# Other states we would want to preserve are excluded by the
# outermost conditional check.
if task.state != TaskState.running:
task.state = TaskState.setting_up

task.on_start()
task.save()

# Note: we set the node task state to `setting_up`, even though
# the task itself may be `running`.
node_task = NodeTasks(
machine_id=machine_id,
task_id=task_id,
state=NodeTaskState.setting_up,
)
node_task.save()

elif state == NodeState.done:
# if tasks are running on the node when it reports as Done
# those are stopped early
node.mark_tasks_stopped_early()

# Model-validated.
#
# This field will be required in the future.
# For now, it is optional for back compat.
done_data = cast(Optional[NodeDoneEventData], state_update.data)
if done_data:
# TODO: do something with this done data
if done_data.error:
logging.error(
"node 'done' with error: machine_id:%s, data:%s",
machine_id,
done_data,
)
else:
logging.debug("No change in Node state")
else:
logging.info("ignoring state updates from the node: %s: %s", machine_id, state)

def on_worker_event_running(
machine_id: UUID, event: WorkerRunningEvent
) -> Optional[Error]:
task = Task.get_by_task_id(event.task_id)
if not isinstance(task, Task):
return task

def on_worker_event(machine_id: UUID, event: WorkerEvent) -> None:
if event.running:
task_id = event.running.task_id
elif event.done:
task_id = event.done.task_id
else:
raise NotImplementedError
node = get_node(machine_id)
if not isinstance(node, Node):
return node

task = get_task_checked(task_id)
node = get_node_checked(machine_id)
node_task = NodeTasks(
machine_id=machine_id, task_id=task_id, state=NodeTaskState.running
machine_id=machine_id, task_id=event.task_id, state=NodeTaskState.running
)
node_task.save()

if event.running:
if task.state not in TaskState.shutting_down():
task.state = TaskState.running
if node.state not in NodeState.ready_for_reset():
node.state = NodeState.busy
node.save()
node_task.save()
if task.state not in TaskState.shutting_down():
task.state = TaskState.running
task.save()

# Start the clock for the task if it wasn't started already
# (as happens in 1.0.0 agents)
task.on_start()
elif event.done:
node_task.delete()

exit_status = event.done.exit_status
if not exit_status.success:
logging.error("task failed. status:%s", exit_status)
task.mark_failed(
Error(
code=ErrorCode.TASK_FAILED,
errors=[
"task failed. exit_status:%s" % exit_status,
event.done.stdout[-4096:],
event.done.stderr[-4096:],
],
)
)
if task.config.debug and (
TaskDebugFlag.keep_node_on_failure in task.config.debug
or TaskDebugFlag.keep_node_on_completion in task.config.debug
):
node.debug_keep_node = True
node.save()

else:
task.mark_stopping()
if (
task.config.debug
and TaskDebugFlag.keep_node_on_completion in task.config.debug
):
node.debug_keep_node = True
node.save()

node.to_reimage(done=True)
else:
err = Error(
code=ErrorCode.INVALID_REQUEST,
errors=["invalid worker event type"],
)
raise RequestException(err)

task.save()
# Start the clock for the task if it wasn't started already
# (as happens in 1.0.0 agents)
task.on_start()

task_event = TaskEvent(task_id=task_id, machine_id=machine_id, event_data=event)
task_event.save()

def on_worker_event_done(machine_id: UUID, event: WorkerDoneEvent) -> Optional[Error]:
task = Task.get_by_task_id(event.task_id)
if not isinstance(task, Task):
return task

def post(req: func.HttpRequest) -> func.HttpResponse:
envelope = parse_request(NodeEventEnvelope, req)
if isinstance(envelope, Error):
return not_ok(envelope, context=ERROR_CONTEXT)
node = get_node(machine_id)
if not isinstance(node, Node):
return node

logging.info(
"node event: machine_id: %s event: %s",
envelope.machine_id,
envelope.event,
node_task = NodeTasks(
machine_id=machine_id, task_id=event.task_id, state=NodeTaskState.running
)
node_task.delete()

exit_status = event.done.exit_status
if not exit_status.success:
logging.error(
"task failed. %s:%s status:%s", task.job_id, task.task_id, exit_status
)
task.mark_failed(
Error(
code=ErrorCode.TASK_FAILED,
errors=[
"task failed. exit_status:%s" % exit_status,
event.done.stdout[-4096:],
event.done.stderr[-4096:],
],
)
)

if task.config.debug and (
TaskDebugFlag.keep_node_on_failure in task.config.debug
or TaskDebugFlag.keep_node_on_completion in task.config.debug
):
node.debug_keep_node = True
node.save()

if isinstance(envelope.event, NodeEvent):
event = envelope.event
elif isinstance(envelope.event, NodeStateUpdate):
event = NodeEvent(state_update=envelope.event)
elif isinstance(envelope.event, WorkerEvent):
event = NodeEvent(worker_event=envelope.event)
else:
err = Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid node event"])
return not_ok(err, context=ERROR_CONTEXT)

if event.state_update:
on_state_update(envelope.machine_id, event.state_update)
return ok(BoolResult(result=True))
elif event.worker_event:
on_worker_event(envelope.machine_id, event.worker_event)
return ok(BoolResult(result=True))
else:
err = Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid node event"])
return not_ok(err, context=ERROR_CONTEXT)
logging.error(
"task done. %s:%s status:%s", task.job_id, task.task_id, exit_status
)
task.mark_stopping()
if (
task.config.debug
and TaskDebugFlag.keep_node_on_completion in task.config.debug
):
node.debug_keep_node = True
node.save()

node.to_reimage(done=True)
task_event = TaskEvent(
task_id=task.task_id, machine_id=machine_id, event_data=event
)
task_event.save()
return None

def main(req: func.HttpRequest) -> func.HttpResponse:
    """Dispatch an incoming HTTP request to the node-event handler.

    Only ``POST`` is accepted; any other method raises a plain ``Exception``
    with the message "invalid method", which is not handled here. For
    ``POST``, the request and the ``post`` handler are passed to
    ``verify_token``. A ``RequestException`` raised during that call is
    turned into a response via ``not_ok`` using ``ERROR_CONTEXT``.
    """
    # The generic Exception is not a RequestException, so it propagates to
    # the caller whether or not it is raised inside the try block.
    if req.method != "POST":
        raise Exception("invalid method")

    try:
        return verify_token(req, post)
    except RequestException as r:
        return not_ok(r.error, context=ERROR_CONTEXT)
def on_worker_event(machine_id: UUID, event: WorkerEvent) -> Optional[Error]:
    """Route a worker event to the handler matching its populated field.

    ``event.running`` is checked first, then ``event.done``; the matching
    payload is forwarded together with ``machine_id`` and the handler's
    result is returned as-is.

    Raises:
        NotImplementedError: neither ``running`` nor ``done`` is set.
    """
    if event.running:
        return on_worker_event_running(machine_id, event.running)
    if event.done:
        return on_worker_event_done(machine_id, event.done)
    raise NotImplementedError
Loading