Skip to content

Commit

Permalink
q-dev: fix tests and cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
piotrbartman committed Oct 14, 2024
1 parent 5f76980 commit 5210963
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 66 deletions.
155 changes: 104 additions & 51 deletions qubesadmin/device_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ def qbool(value):


class DeviceSerializer:
"""
Group of method for serialization of device properties.
"""
ALLOWED_CHARS_KEY = set(
string.digits + string.ascii_letters
+ r"!#$%&()*+,-./:;<>?@[\]^_{|}~")
Expand Down Expand Up @@ -197,7 +200,7 @@ def parse_basic_device_properties(
f"Unrecognized device identity '{properties['device_id']}' "
f"expected '{expected_device.device_id}'"
)
expected._device_id = properties.get('device_id', expected_devid)
properties['device_id'] = properties.get('device_id', expected_devid)

properties['port'] = expected

Expand Down Expand Up @@ -225,8 +228,8 @@ def sanitize_str(
"""
Sanitize given untrusted string.
If `replace_char` is not None, ignore `error_message` and replace invalid
characters with the string.
If `replace_char` is not None, ignore `error_message` and replace
invalid characters with the string.
"""
if replace_char is None:
not_allowed_chars = set(untrusted_value) - allowed_chars
Expand All @@ -249,7 +252,7 @@ class Port:
Attributes:
backend_domain (QubesVM): The domain which exposes devices,
e.g.`sys-usb`.
port_id (str): A unique identifier for the port within the backend domain.
port_id (str): A unique (in backend domain) identifier for the port.
devclass (str): The class of the port (e.g., 'usb', 'pci').
"""
def __init__(self, backend_domain, port_id, devclass):
Expand Down Expand Up @@ -284,6 +287,7 @@ def __str__(self):

@property
def backend_name(self) -> str:
# pylint: disable=missing-function-docstring
if self.backend_domain not in (None, "*"):
return self.backend_domain.name
return "*"
Expand All @@ -292,6 +296,9 @@ def backend_name(self) -> str:
def from_qarg(
cls, representation: str, devclass, domains, blind=False
) -> 'Port':
"""
Parse qrexec argument <back_vm>+<port_id> to retrieve Port.
"""
if blind:
get_domain = domains.get_blind
else:
Expand All @@ -302,6 +309,9 @@ def from_qarg(
def from_str(
cls, representation: str, devclass, domains, blind=False
) -> 'Port':
"""
Parse string <back_vm>:<port_id> to retrieve Port.
"""
if blind:
get_domain = domains.get_blind
else:
Expand All @@ -316,6 +326,9 @@ def _parse(
get_domain: Callable,
sep: str
) -> 'Port':
"""
Parse string representation and return instance of Port.
"""
backend_name, port_id = representation.split(sep, 1)
backend = get_domain(backend_name)
return cls(backend_domain=backend, port_id=port_id, devclass=devclass)
Expand Down Expand Up @@ -364,7 +377,7 @@ def __init__(
self.port: Optional[Port] = port
self._device_id = device_id

def clone(self, **kwargs):
def clone(self, **kwargs) -> 'VirtualDevice':
"""
Clone object and substitute attributes with explicitly given.
"""
Expand All @@ -376,45 +389,64 @@ def clone(self, **kwargs):
return self.__class__(**attr)

@property
def port(self):
def port(self) -> Union[Port, str]:
# pylint: disable=missing-function-docstring
return self._port

@port.setter
def port(self, value):
def port(self, value: Union[Port, str, None]):
# pylint: disable=missing-function-docstring
self._port = value if value is not None else '*'

@property
def device_id(self):
def device_id(self) -> str:
# pylint: disable=missing-function-docstring
if self._device_id is not None:
return self._device_id
return '*'

@property
def backend_domain(self):
def is_device_id_set(self) -> bool:
"""
Check if `device_id` is explicitly set.
"""
return self._device_id is not None

@property
def backend_domain(self) -> Union[QubesVM, str]:
# pylint: disable=missing-function-docstring
if self.port != '*' and self.port.backend_domain is not None:
return self.port.backend_domain
return '*'

@property
def backend_name(self):
def backend_name(self) -> str:
"""
Return backend domain name if any or `*`.
"""
if self.port != '*':
return self.port.backend_name
return '*'

@property
def port_id(self):
def port_id(self) -> str:
# pylint: disable=missing-function-docstring
if self.port != '*' and self.port.port_id is not None:
return self.port.port_id
return '*'

@property
def devclass(self):
def devclass(self) -> str:
# pylint: disable=missing-function-docstring
if self.port != '*' and self.port.devclass is not None:
return self.port.devclass
return '*'

@property
def description(self):
def description(self) -> str:
"""
Return human-readable description of the device identity.
"""
if self.device_id == '*':
return 'any device'
return self.device_id
Expand Down Expand Up @@ -451,17 +483,16 @@ def __lt__(self, other):
if self.port != '*' and other.port == '*':
return False
reprs = {self: [self.port], other: [other.port]}
for obj in reprs:
for obj, obj_repr in reprs.items():
if obj.device_id != '*':
reprs[obj].append(obj.device_id)
obj_repr.append(obj.device_id)
return reprs[self] < reprs[other]
elif isinstance(other, Port):
if isinstance(other, Port):
_other = VirtualDevice(other, '*')
return self < _other
else:
raise TypeError(
f"Comparing instances of {type(self)} and '{type(other)}' "
"is not supported")
raise TypeError(
f"Comparing instances of {type(self)} and '{type(other)}' "
"is not supported")

def __repr__(self):
return f"{self.port!r}:{self.device_id}"
Expand All @@ -478,6 +509,9 @@ def from_qarg(
blind=False,
backend=None,
) -> 'VirtualDevice':
"""
Parse qrexec argument <back_vm>+<port_id>:<device_id> to get device info
"""
if backend is None:
if blind:
get_domain = domains.get_blind
Expand All @@ -492,6 +526,9 @@ def from_str(
cls, representation: str, devclass: Optional[str], domains,
blind=False, backend=None
) -> 'VirtualDevice':
"""
Parse string <back_vm>+<port_id>:<device_id> to get device info
"""
if backend is None:
if blind:
get_domain = domains.get_blind
Expand All @@ -510,6 +547,9 @@ def _parse(
backend,
sep: str
) -> 'VirtualDevice':
"""
Parse string representation and return instance of VirtualDevice.
"""
if backend is None:
backend_name, identity = representation.split(sep, 1)
if backend_name != '*':
Expand Down Expand Up @@ -721,14 +761,23 @@ def _load_classes(bus: str):
return result

def matches(self, other: 'DeviceInterface') -> bool:
"""
Check if this `DeviceInterface` (pattern) matches given one.
The matching is done character by character using the string
representation (`repr`) of both objects. A wildcard character (`'*'`)
in the pattern (i.e., `self`) can match any character in the candidate
(i.e., `other`).
The two representations must be of the same length.
"""
pattern = repr(self)
candidate = repr(other)
if len(pattern) != len(candidate):
return False
for p, c in zip(pattern, candidate):
if p == '*':
for patt, cand in zip(pattern, candidate):
if patt == '*':
continue
if p != c:
if patt != cand:
return False
return True

Expand Down Expand Up @@ -929,7 +978,8 @@ def serialize(self) -> bytes:
'parent_devclass', self.parent_device.devclass)

for key, value in self.data.items():
properties += b' ' + DeviceSerializer.pack_property("_" + key, value)
properties += b' ' + DeviceSerializer.pack_property(
"_" + key, value)

return properties

Expand All @@ -952,6 +1002,7 @@ def deserialize(
device = cls._deserialize(rest, device)
# pylint: disable=broad-exception-caught
except Exception as exc:
print(str(exc), file=sys.stderr)
device = UnknownDevice.from_device(device)

return device
Expand Down Expand Up @@ -1026,36 +1077,28 @@ def device_id(self, value):
class UnknownDevice(DeviceInfo):
# pylint: disable=too-few-public-methods
"""Unknown device - for example, exposed by domain not running currently"""

@staticmethod
def from_device(device) -> 'UnknownDevice':
def from_device(device: VirtualDevice) -> 'UnknownDevice':
"""
Return `UnknownDevice` based on any virtual device.
"""
return UnknownDevice(device.port, device_id=device.device_id)


class AssignmentMode(Enum):
"""
Device assignment modes
"""
MANUAL = "manual"
ASK = "ask-to-attach"
AUTO = "auto-attach"
REQUIRED = "required"


class DeviceAssignment:
""" Maps a device to a frontend_domain.
There are 3 flags `attached`, `automatically_attached` and `required`.
The meaning of valid combinations is as follows:
1. (True, False, False) -> domain is running, device is manually attached
and could be manually detach any time.
2. (True, True, False) -> domain is running, device is attached
and could be manually detach any time (see 4.),
but in the future will be auto-attached again.
3. (True, True, True) -> domain is running, device is attached
and couldn't be detached.
4. (False, Ture, False) -> device is assigned to domain, but not attached
because either (i) domain is halted,
device (ii) manually detached or
(iii) attach to different domain.
5. (False, True, True) -> domain is halted, device assigned to domain
and required to start domain.
"""
Maps a device to a frontend_domain.
"""

def __init__(
Expand Down Expand Up @@ -1116,23 +1159,28 @@ def __lt__(self, other):
"is not supported")

@property
def backend_domain(self):
def backend_domain(self) -> QubesVM:
# pylint: disable=missing-function-docstring
return self.virtual_device.backend_domain

@property
def backend_name(self) -> str:
# pylint: disable=missing-function-docstring
return self.virtual_device.backend_name

@property
def port_id(self):
def port_id(self) -> str:
# pylint: disable=missing-function-docstring
return self.virtual_device.port_id

@property
def devclass(self):
def devclass(self) -> str:
# pylint: disable=missing-function-docstring
return self.virtual_device.devclass

@property
def device_id(self):
def device_id(self) -> str:
# pylint: disable=missing-function-docstring
return self.virtual_device.device_id

@property
Expand Down Expand Up @@ -1241,7 +1289,8 @@ def serialize(self) -> bytes:
'frontend_domain', self.frontend_domain.name)

for key, value in self.options.items():
properties += b' ' + DeviceSerializer.pack_property("_" + key, value)
properties += b' ' + DeviceSerializer.pack_property(
"_" + key, value)

return properties

Expand Down Expand Up @@ -1275,22 +1324,26 @@ def _deserialize(

DeviceSerializer.parse_basic_device_properties(
expected_device, properties)

expected_device = expected_device.clone(
device_id=properties['device_id'])
# we do not need port, we need device
del properties['port']
expected_device._device_id = properties.get(
'device_id', expected_device.device_id)
properties.pop('device_id', None)
properties['device'] = expected_device

return cls(**properties)

def matches(self, device: VirtualDevice) -> bool:
"""
Checks if the given device matches the assignment.
"""
if self.devclass != device.devclass:
return False
if self.backend_domain != device.backend_domain:
return False
if self.port_id != '*' and self.port_id != device.port_id:
if self.port_id not in ('*', device.port_id):
return False
if self.device_id != '*' and self.device_id != device.device_id:
if self.device_id not in ('*', device.device_id):
return False
return True
4 changes: 2 additions & 2 deletions qubesadmin/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
Devices can be of different classes (like 'pci', 'usb', etc.). Each device
class is implemented by an extension.
Devices are identified by pair of (backend domain, `port_id`), where `port_id` is
:py:class:`str`.
Devices are identified by pair of (backend domain, `port_id`), where `port_id`
is :py:class:`str`.
"""
import itertools
from typing import Optional, Iterable
Expand Down
Loading

0 comments on commit 5210963

Please sign in to comment.