Skip to content

Commit

Permalink
refactor: QEMUBaseStrategy and internet state
Browse files Browse the repository at this point in the history
  • Loading branch information
rpoisel committed Jan 3, 2025
1 parent 584dbd1 commit 3f7a34f
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 176 deletions.
4 changes: 3 additions & 1 deletion util/driver/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .base_qemudriver import BaseQEMUDriver
from .custom_qemudriver import CustomQEMUDriver
from .params import QEMUParams
from .stateful_qemudriver import StatefulQEMUDriver

__all__ = [
"QEMUParams",
"BaseQEMUDriver",
"CustomQEMUDriver",
"QEMUParams",
"StatefulQEMUDriver",
]
97 changes: 7 additions & 90 deletions util/strategy/qemu_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,47 +20,35 @@

import attr
import httpx
from driver import CustomQEMUDriver, QEMUParams
from func import retry_exc
from labgrid import step, target_factory
from labgrid.driver import ShellDriver, SSHDriver
from labgrid.driver.exception import ExecutionError
from labgrid.step import Step
from labgrid.strategy import Strategy, StrategyError
from labgrid.util import get_free_port
from openwrt import enable_dhcp

from .status import Status
from .qemu_strategy import QEMUBaseStrategy


@target_factory.reg_driver
@attr.s(eq=False)
class QEMUNetworkStrategy(Strategy):
class QEMUNetworkStrategy(QEMUBaseStrategy):
bindings = {
"qemu": "CustomQEMUDriver",
"shell": "ShellDriver",
"ssh": "SSHDriver",
"params": "QEMUParams",
}

status: Status = attr.ib(default=Status.unknown)
qemu: CustomQEMUDriver | None = None
shell: ShellDriver | None = None
ssh: SSHDriver | None = None
params: QEMUParams | None = None

def __attrs_post_init__(self) -> None:
super().__attrs_post_init__()
assert self.ssh
assert self.params

self._download_image() # keep .gz image
if self.params.overwrite:
logging.info(f"Overwriting image {self.disk_path}")
self._extract_image() # overwrite image if existing

self.__port_forward = None
self.__remote_port = self.ssh.networkservice.port
def on(self) -> None:
self.qemu.on() # type: ignore

def off(self) -> None:
self.qemu.off() # type: ignore

@property
def disk_url(self) -> str:
Expand Down Expand Up @@ -102,74 +90,3 @@ def _extract_image(self) -> None:
gunzip_process.stdin.close()

gunzip_process.wait()

@step(result=True)
def get_remote_address(self) -> str:
return str(self.shell.get_ip_addresses()[0].ip)

@step()
def update_network_service(self) -> None:
assert self.qemu

new_address: str = retry_exc(self.get_remote_address, ExecutionError, "getting the remote address", timeout=20)
networkservice = self.ssh.networkservice

if networkservice.address != new_address:
self.target.deactivate(self.ssh)

if self.__port_forward is not None:
self.qemu.remove_port_forward(*self.__port_forward)

local_port = get_free_port()
local_address = "127.0.0.1"

self.qemu.add_port_forward(
"tcp",
local_address,
local_port,
new_address,
self.__remote_port,
)
self.__port_forward = ("tcp", local_address, local_port)

networkservice.address = local_address
networkservice.port = local_port

@step(args=["state"])
def transition(self, state: Status | str, *, step: Step) -> None:
if not isinstance(state, Status):
state = Status[state]

if state == Status.unknown:
raise StrategyError(f"can not transition to {state}")

elif self.status == state:
step.skip("nothing to do")
return

if state == Status.off:
assert self.target
assert self.qemu

self.target.deactivate(self.qemu)
self.qemu.off()

elif state == Status.shell:
assert self.target
assert self.qemu

# check if target is running
self.qemu.on()
self.target.activate(self.qemu)
self.target.activate(self.shell)

assert self.shell

elif state == Status.ssh:
self.transition(Status.shell)

assert self.shell
enable_dhcp(self.shell)
self.update_network_service()

self.status = state
91 changes: 7 additions & 84 deletions util/strategy/qemu_stateful.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,20 @@
import attr
from driver import StatefulQEMUDriver
from func import retry_exc
from labgrid import step, target_factory
from labgrid.driver import ShellDriver, SSHDriver
from labgrid.driver.exception import ExecutionError
from labgrid.step import Step
from labgrid.strategy import Strategy, StrategyError
from labgrid.util import get_free_port
from openwrt import enable_dhcp
from labgrid import target_factory

from .status import Status
from .qemu_strategy import QEMUBaseStrategy


@target_factory.reg_driver
@attr.s(eq=False)
class QEMUStatefulStrategy(Strategy):
class QEMUStatefulStrategy(QEMUBaseStrategy):
bindings = {
"qemu": "StatefulQEMUDriver",
"shell": "ShellDriver",
"ssh": "SSHDriver",
}

status: Status = attr.ib(default=Status.unknown)
qemu: StatefulQEMUDriver | None = None
shell: ShellDriver | None = None
ssh: SSHDriver | None = None
def on(self) -> None:
pass

def __attrs_post_init__(self) -> None:
super().__attrs_post_init__()
self.__port_forward = None
self.__remote_port = self.ssh.networkservice.port

@step(result=True)
def get_remote_address(self) -> str:
return str(self.shell.get_ip_addresses()[0].ip)

def update_network_service(self) -> None:
assert self.target
assert self.qemu

new_address: str = retry_exc(self.get_remote_address, ExecutionError, "getting the remote address", timeout=20)
networkservice = self.ssh.networkservice # type: ignore

if networkservice.address != new_address:
self.target.deactivate(self.ssh)

if self.__port_forward is not None:
self.qemu.remove_port_forward(*self.__port_forward)

local_port = get_free_port()
local_address = "127.0.0.1"

self.qemu.add_port_forward(
"tcp",
local_address,
local_port,
new_address,
self.__remote_port,
)
self.__port_forward = ("tcp", local_address, local_port)

networkservice.address = local_address
networkservice.port = local_port

@step(args=["state"])
def transition(self, state: Status | str, *, step: Step) -> None:
if not isinstance(state, Status):
state = Status[state]

if state == Status.unknown:
raise StrategyError(f"can not transition to {state}")

elif self.status == state:
step.skip("nothing to do")
return

elif state == Status.shell:
assert self.target
assert self.qemu

# check if target is running
self.target.activate(self.qemu)
self.target.activate(self.shell)

assert self.shell

elif state == Status.ssh:
self.transition(Status.shell)

assert self.shell
enable_dhcp(self.shell)
self.update_network_service()
else:
raise StrategyError(f"no transition found from {self.status} to {status}")

self.status = state
def off(self) -> None:
pass
114 changes: 114 additions & 0 deletions util/strategy/qemu_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from abc import ABC, abstractmethod

import attr
from driver import BaseQEMUDriver, QEMUParams
from func import retry_exc
from labgrid import step
from labgrid.driver import ShellDriver, SSHDriver
from labgrid.driver.exception import ExecutionError
from labgrid.strategy import Strategy, StrategyError
from labgrid.util import get_free_port
from openwrt import enable_dhcp

from .status import Status


class QEMUBaseStrategy(ABC, Strategy):
status: Status = attr.ib(default=Status.unknown)
qemu: BaseQEMUDriver | None = None
shell: ShellDriver | None = None
ssh: SSHDriver | None = None
params: QEMUParams | None = None

def __attrs_post_init__(self) -> None:
super().__attrs_post_init__()
self._port_forward: tuple[str, str, int] | None = None
self._remote_port: int = self.ssh.networkservice.port

@abstractmethod
def on(self) -> None:
raise NotImplementedError()

@abstractmethod
def off(self) -> None:
raise NotImplementedError()

@step(result=True)
def get_remote_address(self) -> str:
assert self.shell

return str(self.shell.get_ip_addresses()[0].ip)

@step()
def update_network_service(self) -> None:
assert self.target
assert self.qemu
assert self.ssh
assert self._remote_port

new_address: str = retry_exc(self.get_remote_address, ExecutionError, "getting the remote address", timeout=20)
networkservice = self.ssh.networkservice

if networkservice.address != new_address:
self.target.deactivate(self.ssh)

if self._port_forward is not None:
self.qemu.remove_port_forward(*self._port_forward)

local_port = get_free_port()
local_address = "127.0.0.1"

self.qemu.add_port_forward(
"tcp",
local_address,
local_port,
new_address,
self._remote_port,
)
self._port_forward = ("tcp", local_address, local_port)

networkservice.address = local_address
networkservice.port = local_port

@step(args=["status"])
def transition(self, status: Status | str) -> None:
if not isinstance(status, Status):
status = Status[status]

if status == Status.unknown:
raise StrategyError(f"can not transition to {status}") # type: ignore

elif self.status == status:
step.skip("nothing to do") # type: ignore
return

if status == Status.off:
assert self.target
assert self.qemu

self.target.deactivate(self.qemu)
self.off()

elif status == Status.shell:
assert self.target
assert self.qemu

self.on()
self.target.activate(self.qemu)
self.target.activate(self.shell)

assert self.shell

elif status == Status.internet:
self.transition(Status.shell)

assert self.shell
enable_dhcp(self.shell)

elif status == Status.ssh:
self.transition(Status.internet)

assert self.shell
self.update_network_service()

self.status = status
3 changes: 2 additions & 1 deletion util/strategy/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ class Status(enum.Enum):
unknown = 0
off = 1
shell = 2
ssh = 3
internet = 3
ssh = 4

0 comments on commit 3f7a34f

Please sign in to comment.