Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup/separate class for swtpm2 #38

Merged
merged 7 commits into from
Jul 31, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 161 additions & 69 deletions test/harness
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import subprocess
import sys
import threading
import tempfile
import textwrap
import time
import yaml

Expand Down Expand Up @@ -182,7 +183,7 @@ class SubpResult():
def __str__(self):
def indent(name, data):
if data is None:
return "{name}: {data}"
return f"{name}:"

if hasattr(data, 'decode'):
data = data.decode('utf-8', errors='ignore')
Expand Down Expand Up @@ -257,27 +258,8 @@ def subp(cmd, capture=True, data=None, rcs=(0,), timeout=None, ksignal=signal.SI
sp = subprocess.Popen(
cmd, stdout=stdout, stderr=stderr, stdin=stdin, cwd=cwd)

def communicate():
try:
(result.out, result.err) = sp.communicate(data)
except Exception as e:
result.exception = e

if timeout is None and ksignal == signal.SIGKILL:
communicate()
else:
thread = threading.Thread(target=communicate)
thread.start()
thread.join(timeout)
if thread.is_alive():
result.exception = subprocess.TimeoutExpired(cmd=cmd, timeout=timeout)
sp.send_signal(ksignal)
thread.join(5)
if thread.is_alive():
print("%s didn't die on %s, sending SIGTERM." %
(cmd[0], ksignal))
sp.send_signal(signal.SIGTERM)
thread.join()
result.out, result.err, result.exception = do_communicate(
proc=sp, data=data, timeout=timeout, sig=ksignal, delay=5)

if result.exception is None:
result.rc = sp.returncode
Expand Down Expand Up @@ -458,6 +440,11 @@ def smash(output, stubefi, kernel, initrd, sbat, cmdline):
os.unlink(cmdline_f)


def ensure_dirs(dlist):
    """Call ensure_dir() on every path in dlist, in order."""
    for path in dlist:
        ensure_dir(path)


def ensure_dir(mdir):
if os.path.isdir(mdir):
return
Expand Down Expand Up @@ -538,6 +525,121 @@ def get_env_boolean(name, default):
raise ValueError(f"environment var {name} had value '{val}'. Expected 'true' or 'false'")


def do_communicate(proc, data=None, timeout=None, sig=signal.SIGTERM, delay=5):
    """Run proc.communicate() with a timeout, signalling the process on expiry.

    proc is expected to behave like a subprocess.Popen object.  data is
    passed as its input, and timeout bounds the first wait.  If that wait
    times out, sig is sent and the process gets delay more seconds to
    finish; if it is still running after that, SIGKILL is sent (unless sig
    already was SIGKILL) and the final wait has no timeout.

    Returns a (stdout, stderr, exception) tuple.  exception is None on a
    clean run, the original TimeoutExpired if the first wait timed out, or
    whatever other Exception the first communicate() call raised (in which
    case stdout and stderr are None).
    """
    try:
        out, err = proc.communicate(input=data, timeout=timeout)
    except subprocess.TimeoutExpired as expired:
        # Remember the timeout to report; the cleanup below still collects
        # whatever output the process produced.
        failure = expired
    except Exception as other:
        return None, None, other
    else:
        return out, err, None

    proc.send_signal(sig)
    try:
        out, err = proc.communicate(timeout=delay)
    except subprocess.TimeoutExpired:
        # The process ignored sig; escalate unless sig already was SIGKILL.
        if sig != signal.SIGKILL:
            proc.send_signal(signal.SIGKILL)
        out, err = proc.communicate()

    return out, err, failure


class swtpm2:
    """Run a ``swtpm socket --tpm2`` process and manage its lifetime.

    start() launches swtpm and blocks until its control socket exists.
    A background thread waits on the process and records its output,
    return code and duration in ``self.result``.  stop() signals the
    process and waits for that thread to finish.
    """

    # Seconds start() waits for the control socket to appear.
    socket_timeout = 10
    # Seconds to wait after a signal before escalating to SIGKILL.
    sigterm_timeout = 5

    # Per-run state, populated by start().
    result = None
    sp = None
    runthread = None
    start_time = None

    def __init__(self, state_d, socket=None, log=None, timeout=None):
        """Describe a swtpm instance; nothing is started here.

        state_d is the TPM state directory.  socket and log default to
        "socket" and "log" inside state_d.  timeout is the maximum run
        time in seconds handed to do_communicate (None means no limit).
        """
        self.state_d = state_d
        self.socket = socket if socket else path_join(self.state_d, "socket")
        self.log = log if log else path_join(self.state_d, "log")
        self.timeout = timeout

    def cmd(self):
        """Return the swtpm command line as an argument list."""
        return ["swtpm", "socket", "--tpmstate=dir=" + self.state_d,
                "--ctrl=type=unixio,path=" + self.socket,
                "--log=level=20,file=" + self.log,
                "--tpm2"]

    @staticmethod
    def _signame(sig):
        """Return a printable name for sig (enum member or plain int)."""
        try:
            return signal.Signals(sig).name
        except ValueError:
            return str(sig)

    def _communicate(self):
        """Thread target: wait for swtpm to exit and fill in self.result."""
        res = self.result
        # Store the error under 'exception', the same attribute start()
        # sets on this result object, so that a timeout is not lost.
        res.out, res.err, res.exception = do_communicate(
            self.sp, timeout=self.timeout, delay=self.sigterm_timeout)
        res.duration = time.time() - self.start_time
        res.rc = self.sp.returncode
        self.sp = None

    def start(self):
        """Launch swtpm and wait until its control socket exists.

        Raises SubpError if the swtpm executable cannot be found or the
        process exits before the socket appears, and TimeoutError if the
        socket does not appear within socket_timeout seconds.
        """
        ensure_dirs([
            os.path.dirname(self.socket),
            os.path.dirname(self.log),
            self.state_d])

        # A stale socket from a previous run would make the readiness
        # check below succeed immediately.
        try:
            os.remove(self.socket)
        except FileNotFoundError:
            pass

        self.start_time = time.time()
        self.result = SubpResult(cmd=self.cmd(), rc=-1)
        try:
            self.sp = subprocess.Popen(
                self.result.cmd,
                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except FileNotFoundError as e:
            self.result.duration = time.time() - self.start_time
            self.result.exception = e
            raise SubpError(self.result) from e

        self.runthread = threading.Thread(target=self._communicate)
        self.runthread.start()

        # wait for the socket to appear to signal system ready
        while True:
            if os.path.exists(self.socket):
                break
            if not self.runthread.is_alive():
                raise SubpError(self.result)
            if time.time() - self.start_time > self.socket_timeout:
                self.stop()
                logging.warning(
                    "swtpm process did not create socket %s within %ds "
                    "after starting", self.socket, self.socket_timeout)
                raise TimeoutError("Timeout waiting for swtpm to create socket")

            time.sleep(.1)

    def stop(self, sig=signal.SIGTERM, delay=None):
        """Signal swtpm with sig and wait for the run thread to finish.

        If the thread is still alive delay seconds after the signal
        (default: sigterm_timeout), SIGKILL is sent and the thread is
        joined without a timeout.  A no-op when nothing was started.
        """
        if delay is None:
            delay = self.sigterm_timeout

        thread = self.runthread
        if thread is None:
            logging.debug("Called stop, but no runthread")
            return

        if thread.is_alive():
            # _communicate() resets self.sp to None from the other thread,
            # so take one snapshot instead of testing and then re-reading.
            sp = self.sp
            if sp:
                sp.send_signal(sig)
            thread.join(delay)
            if thread.is_alive():
                logging.debug(
                    "swtpm process did not exit within %s seconds after %s"
                    ", sending SIGKILL", delay, self._signame(sig))
                sp = self.sp
                if sp:
                    sp.send_signal(signal.SIGKILL)
                thread.join()

        self.runthread = None

    def __repr__(self):
        # Snapshot: the run thread may clear self.sp concurrently.
        sp = self.sp
        return " ".join([
            "swtpm(",
            "pid = %d" % (-1 if sp is None else sp.pid),
            ")",
        ])


class Runner:
def __init__(self, cliargs, workdir=None):
self.cleanup_workdir = True
Expand Down Expand Up @@ -609,7 +711,7 @@ class Runner:
jobq.task_done()

# Turn-on the worker thread.
for i in range(self.num_threads):
for _ in range(self.num_threads):
threading.Thread(target=worker, daemon=True).start()

for testdata in self.tests:
Expand Down Expand Up @@ -639,17 +741,17 @@ class Runner:
ensure_dir(results_d)

rlogfp = open(path_join(results_d, "run.log"), "w")
start = time.time()
def runlog(msg):
rlogfp.write(msg + "\n")
rlogfp.write("[%7.3f] " % (time.time() - start) + msg + "\n")
rlogfp.flush()

qemu_log = path_join(results_d, "qemu.log")
serial_log = path_join(results_d, "serial.log")
start = time.time()
print(f"Starting {name} in {run_d}\n qemu_log: {qemu_log}\n serial: {serial_log}.raw")

try:
self._run(testdata, run_d, qemu_log, serial_log, runlog)
self._run(testdata, run_d, results_d, qemu_log, serial_log, runlog)
except SubpError as e:
if isinstance(e.result.exception, subprocess.TimeoutExpired):
result = "TIMEOUT: " + str(e)
Expand All @@ -666,10 +768,10 @@ class Runner:
res = result.split(":")[0]
print("Finished %s [%.2fs]: %s" % (name, time.time() - start, res))

def _run(self, testdata, run_d, qemu_log, serial_log, log):
name = testdata["name"]
def _run(self, testdata, run_d, results_d, qemu_log, serial_log, log):
esp = path_join(run_d, "esp.raw")

log("generating esp")
gen_esp(
esp, run_d=run_d, mode=self.boot_mode,
kernel=self.kernel, initrd=self.initrd,
Expand All @@ -690,6 +792,7 @@ class Runner:
rel_ocode_src = path_join("..", os.path.basename(ocode_src))
rel_esp = os.path.basename(esp)

tpmd = "./tpm"
cmd_base = [
"qemu-system-x86_64",
"-M", "q35,smm=on" + (",accel=kvm" if self.kvm else ""),
Expand All @@ -704,7 +807,7 @@ class Runner:
"-drive", f"if=pflash,format=raw,file={rel_ovars},snapshot=on",
"-drive", f"file={rel_esp},id=disk00,if=none,format=raw,index=0,snapshot=on",
"-device", "virtio-blk,drive=disk00,serial=esp-image",
"-chardev", "socket,id=chrtpm,path=./tpm/socket",
"-chardev", "socket,id=chrtpm,path=" + path_join(tpmd, "socket"),
"-tpmdev", "emulator,id=tpm0,chardev=chrtpm",
"-device", "tpm-tis,tpmdev=tpm0"]

Expand All @@ -713,28 +816,35 @@ class Runner:
"-vnc", "none",
]

ensure_dir(path_join(run_d, "tpm"))
tpm_cmd = ["swtpm", "socket", "--tpmstate=dir=./tpm",
"--ctrl=type=unixio,path=./tpm/socket",
"--log=level=20,file=./tpm/log",
"--pid=file=./tpm/pid", "--tpm2"]

cmd_inter = cmd_base + ["-serial", "mon:stdio"]
bootscript = textwrap.dedent("""\
#!/bin/sh
mkdir -p {tpmd}
rm -f {tpmd}/socket
{tpmcmd} &
while ! [ -e "{tpmd}/socket" ]; do sleep .1; done
{qemucmd}
r=$?
wait
exit $r
""")

write_file(path_join(run_d, "boot"),
"#!/bin/sh\n" +
" ".join(tpm_cmd) + " &\n" +
" ".join(cmd_inter) + "\n" +
"r=$?\n" + "wait\n" + "exit $r\n")
bootscript.format(
tpmd=tpmd,
tpmcmd=shell_quote(swtpm2(tpmd).cmd()),
qemucmd=shell_quote(cmd_base + ["-serial", "mon:stdio"])))

timeout = self.timeout

tpm_thread = threading.Thread(
target=subp, args=(tpm_cmd,), daemon=True,
kwargs={"cwd": run_d, "rcs": None, "capture": False, "timeout": timeout})
tpm_thread.start()
tpm = swtpm2(
state_d=path_join(run_d, "tpm"),
log=path_join(results_d, "tpm.log"),
timeout=self.timeout)

log("starting tpm")
tpm.start()

log("started tpm, starting qemu")
log("tpm ready, starting qemu")
with open(qemu_log, "wb") as qfp:
qfp.write(b"# " + b" ".join([s.encode("utf-8") for s in cmd]) + b"\n")
qfp.flush()
Expand All @@ -751,31 +861,13 @@ class Runner:
if not rawlog.endswith(b"\n"):
wfp.write(b"\n")

# If the tpm subprocess is still around, give it a SIGTERM
if not tpm_thread.is_alive():
log("TPM thread already exited")
else:
tpm_pidf = path_join(run_d, "tpm", "pid")
tpm_pid = 0
if os.path.exists(tpm_pidf):
pidstr = read_file(tpm_pidf).strip()
try:
tpm_pid = int(pidstr)
except ValueError:
log("WARN: Read non-integer '%s' from tpm pidfile %s" %
(pidstr, tpm_pidf))
else:
log("WARN: TPM Pidfile %s did not exist" % (tpm_pidf))

if tpm_pid != 0:
log("sending SIGTERM to TPM pid %d" % tpm_pid)
try:
os.kill(tpm_pid, signal.SIGTERM)
except ProcessLookupError:
log("TPM pid %d did not exist" % tpm_pid)

tpm_thread.join()
tpm.stop()
log("tpm returned %d" % tpm.result.rc)
if tpm.result.rc != 0:
logging.warn("TPM errored: %s", tpm.result)
log("%s" % tpm.result)

log("finished")
if ret.rc != 0:
raise SubpError(ret)

Expand Down