Skip to content

Commit

Permalink
Fix typing
Browse files Browse the repository at this point in the history
  • Loading branch information
bobot committed Jun 17, 2024
1 parent 15b7725 commit a6952ac
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 17 deletions.
39 changes: 32 additions & 7 deletions smtcomp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@
from rich.console import Console
import smtcomp.test_solver as test_solver
from concurrent.futures import ThreadPoolExecutor

from smtcomp.benchexec import get_suffix
from smtcomp.scramble_benchmarks import benchmark_files_dir
from smtcomp.utils import *
import re

app = typer.Typer()

Expand Down Expand Up @@ -551,13 +554,35 @@ def read_submission(file: Path) -> defs.Submission:


@app.command()
def check_model_locally(cachedir: Path, resultdirs: list[Path], max_workers: int = 8) -> None:
l: list[tuple[results.RunId, results.Run, model_validation.ValidationResult]] = []
def check_model_locally(
cachedir: Path, resultdirs: list[Path], max_workers: int = 8, outdir: Optional[Path] = None
) -> None:
l: list[tuple[results.RunId, results.Run, model_validation.ValidationError]] = []
with Progress() as progress:
with ThreadPoolExecutor(max_workers) as executor:
for resultdir in resultdirs:
l2 = model_validation.check_results_locally(cachedir, resultdir, executor, progress)
for rid, r, result in l2:
if result.status != defs.Status.Sat:
l.append((rid, r, result))
print(l)
l.extend(filter_map(map_none3(model_validation.is_error), l2))
keyfunc = lambda v: v[0].solver
l.sort(key=keyfunc)
d = itertools.groupby(l, key=keyfunc)
t = Tree("Unvalidated models")
for solver, rs in d:
t2 = t.add(solver)
for rid, r, result in rs:
stderr = result.stderr.strip().replace("\n", ", ")
t2.add(f"{r.basename}: {stderr}")
print(t)
if outdir is not None:
for solver, models in d:
dst = outdir / solver
dst.mkdir(parents=True, exist_ok=True)
for rid, r, result in models:
filedir = benchmark_files_dir(cachedir, rid.track)
logic = rid.includefile.removesuffix(get_suffix(rid.track))
basename = r.basename.removesuffix(".yml") + ".smt2"
basename_model = r.basename.removesuffix(".yml") + ".rsmt2"
smt2_file = filedir / logic / basename
(dst / basename).unlink(missing_ok=True)
(dst / basename).symlink_to(smt2_file)
(dst / basename_model).write_text(result.model)
42 changes: 32 additions & 10 deletions smtcomp/model_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,36 @@


@dataclass
class ValidationResult:
class ValidationOk:
stderr: str


@dataclass
class ValidationError:
status: defs.Status
stderr: str
model: str | None


@dataclass
class NoValidation:
"""No validation possible"""


noValidation = NoValidation()

Validation = ValidationOk | ValidationError | NoValidation


def is_error(x: Validation) -> ValidationError | None:
match x:
case ValidationError(_):
return x
case ValidationOk(_) | NoValidation():
return None


def check_locally(smt2_file: Path, model: str) -> ValidationResult:
def check_locally(smt2_file: Path, model: str) -> Validation:
r = subprocess.run(
[
"dolmen",
Expand All @@ -54,39 +78,37 @@ def check_locally(smt2_file: Path, model: str) -> ValidationResult:
)
match r.returncode:
case 0:
status = defs.Status.Sat
return ValidationOk(r.stderr.decode())
case 5:
status = defs.Status.Unsat
case 2:
# LimitReached
status = defs.Status.Unknown
case _:
status = defs.Status.Unknown
return ValidationResult(status, r.stderr.decode())
return ValidationError(status, r.stderr.decode(), model)


def check_result_locally(
cachedir: Path, logfiles: results.LogFile, rid: results.RunId, r: results.Run
) -> ValidationResult:
def check_result_locally(cachedir: Path, logfiles: results.LogFile, rid: results.RunId, r: results.Run) -> Validation:
if r.status == "true":
filedir = benchmark_files_dir(cachedir, rid.track)
logic = rid.includefile.removesuffix(get_suffix(rid.track))
smt2_file = filedir / logic / (r.basename.removesuffix(".yml") + ".smt2")
model = logfiles.get_output(rid, r.basename)
return check_locally(smt2_file, model)
else:
return ValidationResult(defs.Status.Unknown, "")
return noValidation


def check_results_locally(
cachedir: Path, resultdir: Path, executor: ThreadPoolExecutor, progress: Progress
) -> list[tuple[results.RunId, results.Run, ValidationResult]]:
) -> list[tuple[results.RunId, results.Run, Validation]]:
with results.LogFile(resultdir) as logfiles:
l = [(r.runid, b) for r in results.parse_results(resultdir) for b in r.runs if b.status == "true"]
return list(
progress.track(
executor.map((lambda v: (v[0], v[1], check_result_locally(cachedir, logfiles, v[0], v[1]))), l),
total=len(l),
description="checking models",
description=f"checking models for {resultdir.name}",
)
)
23 changes: 23 additions & 0 deletions smtcomp/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import *

U = TypeVar("U")
V = TypeVar("V")
W1 = TypeVar("W1")
W2 = TypeVar("W2")


def filter_map(f: Callable[[U], V | None], i: Iterable[U]) -> Iterable[V]:
i2 = map(f, i)
i3: Iterable[V | None] = filter(lambda x: x is not None, i2)
return cast(Iterable[V], i3)


def map_none3(f: Callable[[U], V | None]) -> Callable[[Tuple[W1, W2, U]], Tuple[W1, W2, V] | None]:
def g(x: Tuple[W1, W2, U]) -> Tuple[W1, W2, V] | None:
y = f(x[2])
if y is None:
return None
else:
return (x[0], x[1], y)

return g

0 comments on commit a6952ac

Please sign in to comment.