Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup/separate class for swtpm2 #38

Merged
merged 7 commits into from
Jul 31, 2023
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 122 additions & 45 deletions test/harness
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class SubpResult():
def __str__(self):
def indent(name, data):
if data is None:
return "{name}: {data}"
return f"{name}:"

if hasattr(data, 'decode'):
data = data.decode('utf-8', errors='ignore')
Expand Down Expand Up @@ -274,9 +274,9 @@ def subp(cmd, capture=True, data=None, rcs=(0,), timeout=None, ksignal=signal.SI
sp.send_signal(ksignal)
thread.join(5)
if thread.is_alive():
print("%s didn't die on %s, sending SIGTERM." %
print("%s didn't die on %s, sending SIGKILL." %
(cmd[0], ksignal))
sp.send_signal(signal.SIGTERM)
sp.send_signal(signal.SIGKILL)
thread.join()

if result.exception is None:
Expand Down Expand Up @@ -458,6 +458,11 @@ def smash(output, stubefi, kernel, initrd, sbat, cmdline):
os.unlink(cmdline_f)


def ensure_dirs(dlist):
for d in dlist:
ensure_dir(d)


def ensure_dir(mdir):
if os.path.isdir(mdir):
return
Expand Down Expand Up @@ -538,6 +543,102 @@ def get_env_boolean(name, default):
raise ValueError(f"environment var {name} had value '{val}'. Expected 'true' or 'false'")


class swtpm2:
socket_timeout = 10
sigterm_timeout = 5

result, sp, runthread, start_time = (None, None, None, None)

def __init__(self, state_d, socket=None, log=None, timeout=None):
self.state_d = state_d
self.socket = socket if socket else path_join(self.state_d, "socket")
self.log = log if log else path_join(self.state_d, "log")
self.timeout = timeout

def cmd(self):
return ["swtpm", "socket", "--tpmstate=dir=" + self.state_d,
"--ctrl=type=unixio,path=" + self.socket,
"--log=level=20,file=" + self.log,
"--tpm2"]

def _communicate(self):
try:
(self.result.out, self.result.err) = self.sp.communicate(self.timeout)
except Exception as e:
self.result.exception = e
else:
self.result.rc = self.sp.returncode
finally:
self.sp = None
self.result.duration = time.time() - self.start_time

def start(self):
ensure_dirs([
os.path.dirname(self.socket),
os.path.dirname(self.log),
self.state_d])

try:
os.remove(self.socket)
except FileNotFoundError:
pass

self.start_time = time.time()
self.result = SubpResult(cmd=self.cmd(), rc=-1)
try:
self.sp = subprocess.Popen(
self.result.cmd,
stdout=subprocess.PIPE, stderr = subprocess.PIPE)
except FileNotFoundError as e:
self.result.duration = time.time() - self.start_time
self.result.exception = e
raise SubpError(self.result)

self.runthread = threading.Thread(target=self._communicate)
self.runthread.start()

# wait for the socket to appear to signal system ready
while True:
if os.path.exists(self.socket):
break
if not self.runthread.is_alive():
raise SubpError(self.result)
if time.time() - self.start_time > self.socket_timeout:
self.stop()
logging.warn(
"swtpm process did not create socket %s within %ds "
"after starting", self.socket, self.socket_timeout)
raise TimeoutError("Timeout waiting for swtpm to create socket")

time.sleep(.1)


def stop(self):
if self.runthread is None:
logging.debug("Called stop, but no runthread")
return
if self.runthread.is_alive():
if self.sp:
self.sp.send_signal(signal.SIGTERM)
self.runthread.join(self.sigterm_timeout)
if self.runthread.is_alive():
logging.debug(
"swtpm process did not exit within %d seconds of SIGTERM"
", sending SIGKILL", self.sigterm_timeout)
if self.sp:
self.sp.send_signal(signal.SIGKILL)
self.runthread.join()

self.runthread = None

def __repr__(self):
return " ".join([
"swtpm(",
"pid = %d" % (-1 if self.sp is None else self.sp.pid),
")",
])


class Runner:
def __init__(self, cliargs, workdir=None):
self.cleanup_workdir = True
Expand Down Expand Up @@ -609,7 +710,7 @@ class Runner:
jobq.task_done()

# Turn-on the worker thread.
for i in range(self.num_threads):
for _ in range(self.num_threads):
threading.Thread(target=worker, daemon=True).start()

for testdata in self.tests:
Expand Down Expand Up @@ -649,7 +750,7 @@ class Runner:
print(f"Starting {name} in {run_d}\n qemu_log: {qemu_log}\n serial: {serial_log}.raw")

try:
self._run(testdata, run_d, qemu_log, serial_log, runlog)
self._run(testdata, run_d, results_d, qemu_log, serial_log, runlog)
except SubpError as e:
if isinstance(e.result.exception, subprocess.TimeoutExpired):
result = "TIMEOUT: " + str(e)
Expand All @@ -666,8 +767,7 @@ class Runner:
res = result.split(":")[0]
print("Finished %s [%.2fs]: %s" % (name, time.time() - start, res))

def _run(self, testdata, run_d, qemu_log, serial_log, log):
name = testdata["name"]
def _run(self, testdata, run_d, results_d, qemu_log, serial_log, log):
esp = path_join(run_d, "esp.raw")

gen_esp(
Expand Down Expand Up @@ -713,26 +813,22 @@ class Runner:
"-vnc", "none",
]

ensure_dir(path_join(run_d, "tpm"))
tpm_cmd = ["swtpm", "socket", "--tpmstate=dir=./tpm",
"--ctrl=type=unixio,path=./tpm/socket",
"--log=level=20,file=./tpm/log",
"--pid=file=./tpm/pid", "--tpm2"]

cmd_inter = cmd_base + ["-serial", "mon:stdio"]

write_file(path_join(run_d, "boot"),
"#!/bin/sh\n" +
" ".join(tpm_cmd) + " &\n" +
" ".join(cmd_inter) + "\n" +
"mkdir -p tpm\n" +
" ".join(swtpm2("./tpm").cmd()) + " &\n" +
" ".join(cmd_base + ["-serial", "mon:stdio"]) + "\n"
"r=$?\n" + "wait\n" + "exit $r\n")

timeout = self.timeout

tpm_thread = threading.Thread(
target=subp, args=(tpm_cmd,), daemon=True,
kwargs={"cwd": run_d, "rcs": None, "capture": False, "timeout": timeout})
tpm_thread.start()
tpm = swtpm2(
state_d=path_join(run_d, "tpm"),
log=path_join(results_d, "tpm.log"),
timeout=self.timeout)

log("starting tpm")
tpm.start()

log("started tpm, starting qemu")
with open(qemu_log, "wb") as qfp:
Expand All @@ -751,30 +847,11 @@ class Runner:
if not rawlog.endswith(b"\n"):
wfp.write(b"\n")

# If the tpm subprocess is still around, give it a SIGTERM
if not tpm_thread.is_alive():
log("TPM thread already exited")
else:
tpm_pidf = path_join(run_d, "tpm", "pid")
tpm_pid = 0
if os.path.exists(tpm_pidf):
pidstr = read_file(tpm_pidf).strip()
try:
tpm_pid = int(pidstr)
except ValueError:
log("WARN: Read non-integer '%s' from tpm pidfile %s" %
(pidstr, tpm_pidf))
else:
log("WARN: TPM Pidfile %s did not exist" % (tpm_pidf))

if tpm_pid != 0:
log("sending SIGTERM to TPM pid %d" % tpm_pid)
try:
os.kill(tpm_pid, signal.SIGTERM)
except ProcessLookupError:
log("TPM pid %d did not exist" % tpm_pid)

tpm_thread.join()
tpm.stop()
log("tpm returned %d" % tpm.result.rc)
if tpm.result.rc != 0:
logging.warn("TPM errored: %s", tpm.result)
log("%s" % tpm.result)

if ret.rc != 0:
raise SubpError(ret)
Expand Down