Skip to content

Commit

Permalink
Run name safety improvements (closes #30).
Browse files Browse the repository at this point in the history
  • Loading branch information
cgevans committed Apr 5, 2023
1 parent a9ca08a commit e51ea3f
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 19 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@ SPDX-License-Identifier: AGPL-3.0-only

# Changelog

## Version 0.9.1 (dev)
## Version 0.9.1

- Minor bug fixes and dependency updates (to fix pandas errors).
- Fixes to support Pandas 2.0.
- Ensure that some invalid characters are not used in run names (which are used as file names on the machine).
- When loading an experiment from the machine, check for run files named both with spaces and with spaces replaced by underscores (in case the run was started outside of qslib).
- Parse IOError messages from the machine.

## Version 0.9.0

Expand Down
76 changes: 59 additions & 17 deletions src/qslib/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,25 @@
from .machine import Machine
from .processors import NormRaw, Processor
from .protocol import Protocol, Stage, Step
from .qs_is_protocol import QS_IOError
from .rawquant_compat import _fdc_to_rawdata
from .version import __version__

if TYPE_CHECKING: # pragma: no cover
import matplotlib.pyplot as plt

# Characters that are refused in run names.  The run name is used as a file
# name, so for now every character here is simply assumed to be problematic
# and is rejected rather than escaped.
INVALID_NAME_RE = re.compile(r"[\[\]{}!/:;@&=+$,?#|\\]")


def _safe_exp_name(name: str) -> str:
    """Return ``name`` with each " " replaced by "_".

    Parameters
    ----------
    name : str
        the run name to make safe.

    Returns
    -------
    str
        ``name`` with spaces replaced by underscores.

    Raises
    ------
    ValueError
        if ``name`` contains any character matched by ``INVALID_NAME_RE``.
        Every distinct offending character is listed in the message, in
        order of first appearance.
    """
    # dict.fromkeys de-duplicates while preserving first-seen order, so the
    # message names each offending character exactly once.
    bad_chars = "".join(dict.fromkeys(INVALID_NAME_RE.findall(name)))

    if bad_chars:
        raise ValueError(f"Invalid characters ({bad_chars}) in run name: {name}.")

    return name.replace(" ", "_")


TEMPLATE_NAME = "ruo"

Expand Down Expand Up @@ -478,8 +491,8 @@ def _repr_markdown_(self) -> str:

@property
def runtitle_safe(self) -> str:
"""Run name with " " replaced by "_"."""
return self.name.replace(" ", "_")
"""Run name with " " replaced by "_"; raises ValueError if name has other problematic characters."""
return _safe_exp_name(self.name)

@property
def rawdata(self) -> pd.DataFrame:
Expand Down Expand Up @@ -1045,6 +1058,10 @@ def __init__(
else:
self.name = _nowuuid()

# Ensure that name is safe for use as a filename:
# this will raise a ValueError for us if it isn't.
self.runtitle_safe

if protocol is None:
self.protocol = Protocol([Stage([Step(60, 25)])])
else:
Expand Down Expand Up @@ -1249,15 +1266,18 @@ def from_uncollected(
machine = exp._ensure_machine(machine)

with machine.ensured_connection():
crt = name

if move:
raise NotImplementedError

if not crt:
raise ValueError("Nothing is currently running.")

z = machine.read_dir_as_zip(crt, leaf="EXP")
try:
z = machine.read_dir_as_zip(_safe_exp_name(name), leaf="EXP")
except QS_IOError:
try:
z = machine.read_dir_as_zip(name, leaf="EXP")
except QS_IOError:
raise ValueError(
f"Could not find experiment {name} in uncollect runs on {machine}."
)

z.extractall(exp._dir_base)

Expand All @@ -1267,12 +1287,15 @@ def from_uncollected(

@classmethod
def from_machine_storage(cls, machine: MachineReference, name: str) -> Experiment:
"""Create an experiment from the one currently running on a machine.
"""Create an experiment from the machine's storage.
Parameters
----------
machine : Machine
the machine to connect to
the machine to connect to.
name: str
the name of the run to collect.
Returns
-------
Expand All @@ -1284,10 +1307,20 @@ def from_machine_storage(cls, machine: MachineReference, name: str) -> Experimen
machine = exp._ensure_machine(machine)

with machine.ensured_connection():
try:
o = machine.read_file(name + ".eds", context="public_run_complete")
except FileNotFoundError:
o = machine.read_file(name, context="public_run_complete")
o = None
for possible_name in [
_safe_exp_name(name) + ".eds",
_safe_exp_name(name),
name + ".eds",
name,
]:
try:
o = machine.read_file(possible_name, context="public_run_complete")
break
except QS_IOError:
continue
if o is None:
raise FileNotFoundError(f"Could not find {name} on {machine.host}.")

z = zipfile.ZipFile(io.BytesIO(o))

Expand Down Expand Up @@ -1319,13 +1352,22 @@ def from_machine(cls, machine: MachineReference, name: str) -> Experiment:
if isinstance(machine, str):
machine = Machine(machine)

safename = _safe_exp_name(name)

with machine.ensured_connection():
if name == machine.current_run_name:
if machine.current_run_name in [safename, name]:
exp = cls.from_running(machine)
elif name in machine.list_runs_in_storage():
return exp

storage_runs = machine.list_runs_in_storage()
if (name in storage_runs) or (safename in storage_runs):
exp = cls.from_machine_storage(machine, name)
elif name + "/" in machine.list_files("", verbose=False, leaf="EXP"):
return exp

exp_runs = machine.list_files("", verbose=False, leaf="EXP")
if ((name + "/") in exp_runs) or ((safename + "/") in exp_runs):
exp = cls.from_uncollected(machine, name)

else:
raise FileNotFoundError(f"Could not find run {name} on {machine.host}.")
return exp
Expand Down
23 changes: 22 additions & 1 deletion src/qslib/qs_is_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dataclasses import dataclass
from typing import Any, Coroutine, Optional, Protocol, Type

from .scpi_commands import AccessLevel, SCPICommand
from .scpi_commands import AccessLevel, SCPICommand, _arglist

NL_OR_Q = re.compile(rb"(?:\n|<(/?)([\w.]+)[ *]*>)")
Q_ONLY = re.compile(rb"<(/?)([\w.]+)[ *]*>")
Expand Down Expand Up @@ -70,11 +70,31 @@ def parse(command: str, ref_index: str, response: str) -> CommandError:

@dataclass
class UnparsedCommandError(CommandError):
"""The machine has returned an error that we are not familiar with,
and that we haven't parsed."""

command: Optional[str]
ref_index: Optional[str]
response: str


@dataclass
class QS_IOError(CommandError):
    """The machine has returned an IOError in response to a command.

    The error text is expected to match ``<options> --> <message>``.
    """

    command: str
    # Text following " --> " in the error text.
    message: str
    # Options parsed (with ``_arglist``) from the text preceding " --> ".
    data: dict[str, str]

    @classmethod
    def parse(cls, command: str, ref_index: str, message: str) -> QS_IOError:
        """Build a QS_IOError from the error text returned by the machine.

        Parameters
        ----------
        command : str
            the command associated with the error.
        ref_index : str
            not used by this parser.
        message : str
            the error text, expected to match ``<options> --> <message>``.

        Raises
        ------
        ValueError
            if ``message`` does not match the expected pattern.
        """
        m = re.match(r"(.*) --> (.*)", message)
        if not m:
            # Same exception type as before, but say what failed to parse.
            raise ValueError(f"Could not parse IOError message: {message!r}")

        data = _arglist.parse_string(m[1])[0].opts

        return cls(command, m[2], data)


@dataclass
class InsufficientAccess(CommandError):
command: str
Expand Down Expand Up @@ -145,6 +165,7 @@ def parse(cls, command: str, ref_index: str, message: str) -> NoMatch:
"AccessLevelExceeded": AccessLevelExceeded,
"InvocationError": InvocationError,
"NoMatch": NoMatch,
"IOError": QS_IOError,
}


Expand Down
6 changes: 6 additions & 0 deletions tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,9 @@ def test_fail_plots():

with pytest.raises(ValueError, match="no data available"):
exp.plot_anneal_melt()


@pytest.mark.parametrize("ch", ["/", "!", "}"])
def test_unsafe_names(ch):
    """An Experiment name containing an invalid character raises ValueError."""
    # Imported locally so this test does not depend on the module's imports.
    import re

    # Escape the character so that regex metacharacters can be added to the
    # parameter list without breaking the match pattern.
    expected = r"Invalid characters \(" + re.escape(ch) + r"\)"

    with pytest.raises(ValueError, match=expected):
        Experiment(name=f"a{ch}b")

0 comments on commit e51ea3f

Please sign in to comment.