Skip to content

Commit

Permalink
Fix up Justfile
Browse files Browse the repository at this point in the history
  • Loading branch information
charmoniumQ committed Jul 17, 2024
1 parent 687bc47 commit ca93ee5
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 91 deletions.
26 changes: 3 additions & 23 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,10 @@ jobs:
uses: actions/checkout@v2

- name: Set up Nix
uses: cachix/install-nix-action@v14
uses: cachix/install-nix-action@v27
with:
extra_nix_config: |
experimental-features = nix-command flakes
- name: Run Nix development shell
run: nix develop --command just format-nix

- name: Format Python code with Black
run: nix develop --command just format-python

- name: Check Python code with Ruff
run: nix develop --command just check-ruff

- name: Check Python code with Mypy
run: nix develop --command just check-mypy

- name: Run tests
run: nix develop --command just test

- name: Run flake checks (optional)
run: nix develop --command just flake-check
continue-on-error: true

- name: Run developer tests
if: github.event_name == 'push' && github.event.head_commit.message == 'run dev tests'
run: nix develop --command just test-dev
- name: Run Just in development shell
run: nix develop --command just on-push
46 changes: 30 additions & 16 deletions Justfile
Original file line number Diff line number Diff line change
@@ -1,25 +1,39 @@
format-nix:
alejandra .
fix-format-nix:
#alejandra .

format-python:
#black probe_src/probe_py
check-format-nix:
#alejandra --check .

fix-ruff:
#ruff format probe_src
ruff check --fix probe_src

check-ruff:
ruff check probe_src/probe_py
#ruff format --check probe_src
ruff check probe_src

check-mypy:
cd probe_src
mypy --package probe_py --strict
MYPYPATH=probe_src mypy --strict --package arena
MYPYPATH=probe_src mypy --strict --package probe_py
#mypy --strict probe_src/libprobe

compile-fresh-libprobe:
make --directory=probe_src/libprobe clean
make --directory=probe_src/libprobe all

compile-libprobe:
make --directory=probe_src/libprobe all

test:
#cd probe_src
#make --directory=libprobe all
#python -m pytest .
test: compile-fresh-libprobe
#cd probe_src && python -m pytest .

test-dev:
cd probe_src
make --directory=libprobe all
python -m pytest . --failed-first --maxfail=1
test-dev: compile-libprobe
make --directory=probe_src/libprobe all
#cd probe_src && python -m pytest . --failed-first --maxfail=1

flake-check:
check-flake:
nix flake check --all-systems

pre-commit: fix-format-nix fix-ruff check-mypy check-flake test-dev

on-push: check-format-nix check-ruff check-mypy check-flake test
4 changes: 2 additions & 2 deletions probe_src/libprobe/generator/dump_ast.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pathlib
import tempfile
import pycparser
import pycparser.c_generator
import pycparser # type: ignore
import pycparser.c_generator # type: ignore
import sys


Expand Down
155 changes: 105 additions & 50 deletions probe_src/libprobe/generator/gen_libc_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,94 @@

from __future__ import annotations
import dataclasses
import re
import pycparser
import pycparser.c_generator
import pycparser # type: ignore
import pycparser.c_generator # type: ignore
import typing
import tempfile
import pathlib
import collections.abc
import sys


class GccCGenerator(pycparser.c_generator.CGenerator):
_T = typing.TypeVar("_T")
def expect_type(typ: type[_T], data: typing.Any) -> _T:
if not isinstance(data, typ):
raise TypeError(f"Expected type {typ} for {data}")
return data


if typing.TYPE_CHECKING:
class CGenerator:
def _parenthesize_if(self, n: Node, condition: typing.Callable[[Node], bool]) -> str: ...
def _generate_decl(self, n: pycparser.c_ast.Node) -> str: ...
def visit(self, n: pycparser.c_ast.Node | str | list[str]) -> str: ...
def _visit_expr(self, n: pycparser.c_ast.Node) -> str: ...
def _make_indent(self) -> str: ...
indent_level: int
class Node:
pass
class IdentifierType(Node):
names: list[str]
def __init__(self, names: list[str]) -> None: ...
class Assignment(Node):
def __init__(self, op: str, lvalue: Node, rvalue: Node): ...
op: str
lvalue: Node
rvalue: Node
class Compound(Node):
block_items: list[Node]
class ID(Node):
name: str
class Decl(Node):
def __init__(self, name: str, quals: list[str], align: list[str], storage: list[str], funcspec: list[str], type: TypeDecl, init: Node | None, bitsize : Node | None) -> None: ...
name: str
quals: list[Node]
align: list[Node]
storage: list[Node]
funcspec: list[Node]
type: TypeDecl
init: Node | None
bitsize: Node | None
class TypeDecl(Node):
def __init__(self, declname: str, quals: list[Node], align: Node | None, type: Node) -> None: ...
declname: str
quals: list[Node]
align: Node | None
type: Node
class FuncDecl(Node):
args: ParamList
type: TypeDecl
class ParamList(Node):
params: list[Decl]
else:
CGenerator = pycparser.c_generator.CGenerator
Node = pycparser.c_ast.Node
IdentifierType = pycparser.c_ast.IdentifierType
Assignment = pycparser.c_ast.Assignment
Compound = pycparser.c_ast.Compound
Decl = pycparser.c_ast.Decl
TypeDecl = pycparser.c_ast.TypeDecl
ID = pycparser.c_ast.ID
FuncDecl = pycparser.c_ast.FuncDecl
ParamList = pycparser.c_ast.ParamList


class GccCGenerator(CGenerator):
"""A C generator that is able to emit gcc statement-expr ({...;})"""

def visit_Assignment(self, n):
def visit_Assignment(self, n: Assignment) -> str:
rval_str = self._parenthesize_if(
n.rvalue,
lambda n: isinstance(n, (pycparser.c_ast.Assignment, pycparser.c_ast.Compound)),
lambda n: isinstance(n, (Assignment, Compound)),
)
return '%s %s %s' % (self.visit(n.lvalue), n.op, rval_str)

def visit_Decl(self, n, no_type=False):
def visit_Decl(self, n: Decl, no_type: bool = False) -> str:
s = n.name if no_type else self._generate_decl(n)
if n.bitsize: s += ' : ' + self.visit(n.bitsize)
if n.bitsize:
s += ' : ' + self.visit(n.bitsize)
if n.init:
s += ' = ' + self._parenthesize_if(n.init, lambda n: isinstance(n, (pycparser.c_ast.Assignment, pycparser.c_ast.Compound)))
s += ' = ' + self._parenthesize_if(n.init, lambda n: isinstance(n, (Assignment, pycparser.c_ast.Compound)))
return s

def _parenthesize_if(self, n, condition):
def _parenthesize_if(self, n: Node, condition: typing.Callable[[Node], bool]) -> str:
self.indent_level += 2
s = self._visit_expr(n)
self.indent_level -= 2
Expand All @@ -42,12 +102,12 @@ def _parenthesize_if(self, n, condition):
return s


def is_void(node: pycparser.c_ast.Node) -> bool:
return isinstance(node.type, pycparser.c_ast.IdentifierType) and node.type.names[0] == "void"
def is_void(node: TypeDecl) -> bool:
return isinstance(node.type, IdentifierType) and node.type.names[0] == "void"


def define_var(var_type: pycparser.c_ast.Node, var_name: str, value: pycparser.c_ast.Node) -> pycparser.c_ast.Decl:
return pycparser.c_ast.Decl(
def define_var(var_type: Node, var_name: str, value: Node) -> Decl:
return Decl(
name=var_name,
quals=[],
align=[],
Expand All @@ -64,12 +124,12 @@ def define_var(var_type: pycparser.c_ast.Node, var_name: str, value: pycparser.c
)


void = pycparser.c_ast.IdentifierType(names=['void'])
void = IdentifierType(names=['void'])

c_ast_int = pycparser.c_ast.IdentifierType(names=['int'])
c_ast_int = IdentifierType(names=['int'])


def ptr_type(type: pycparser.c_ast.Node) -> pycparser.c_ast.PtrDecl:
def ptr_type(type: Node) -> pycparser.c_ast.PtrDecl:
return pycparser.c_ast.PtrDecl(
quals=[],
type=pycparser.c_ast.TypeDecl(
Expand Down Expand Up @@ -104,22 +164,22 @@ def ptr_type(type: pycparser.c_ast.Node) -> pycparser.c_ast.PtrDecl:
class ParsedFunc:
name: str
# Using tuples rather than lists since tuples are covariant
params: typing.Sequence[tuple[str, pycparser.c_ast.Node]]
return_type: pycparser.c_ast.Node
params: typing.Sequence[tuple[str, TypeDecl]]
return_type: TypeDecl
variadic: bool = False
stmts: typing.Sequence[pycparser.c_ast.Node] = ()
stmts: typing.Sequence[Node] = ()

@staticmethod
def from_decl(decl: pycparser.c_ast.Decl) -> ParsedFunc:
def from_decl(decl: Decl) -> ParsedFunc:
return ParsedFunc(
name=decl.name,
params=tuple(
(param_decl.name, param_decl.type)
for param_decl in decl.type.args.params
if isinstance(param_decl, pycparser.c_ast.Decl)
for param_decl in expect_type(FuncDecl, decl.type).args.params
if isinstance(param_decl, Decl)
),
return_type=decl.type.type,
variadic=isinstance(decl.type.args.params[-1], pycparser.c_ast.EllipsisParam),
return_type=expect_type(FuncDecl, decl.type).type,
variadic=isinstance(expect_type(FuncDecl, decl.type).args.params[-1], pycparser.c_ast.EllipsisParam),
)

@staticmethod
Expand All @@ -133,7 +193,7 @@ def declaration(self) -> pycparser.c_ast.FuncDecl:
return pycparser.c_ast.FuncDecl(
args=pycparser.c_ast.ParamList(
params=[
pycparser.c_ast.Decl(
Decl(
name=param_name,
quals=[],
align=[],
Expand All @@ -156,7 +216,7 @@ def declaration(self) -> pycparser.c_ast.FuncDecl:

def definition(self) -> pycparser.c_ast.FuncDef:
return pycparser.c_ast.FuncDef(
decl=pycparser.c_ast.Decl(
decl=Decl(
name=self.name,
quals=[],
align=[],
Expand All @@ -183,9 +243,9 @@ def definition(self) -> pycparser.c_ast.FuncDef:
funcs = {
**orig_funcs,
**{
node.name: dataclasses.replace(orig_funcs[node.init.name], name=node.name)
node.name: dataclasses.replace(orig_funcs[typing.cast(ID, node.init).name], name=node.name)
for node in ast.ext
if isinstance(node, pycparser.c_ast.Decl) and isinstance(node.type, pycparser.c_ast.TypeDecl) and node.type.type.names == ["fn"]
if isinstance(node, Decl) and isinstance(node.type, pycparser.c_ast.TypeDecl) and node.type.type.names == ["fn"]
},
}
# funcs = {
Expand All @@ -194,7 +254,7 @@ def definition(self) -> pycparser.c_ast.FuncDef:
# }
func_prefix = "unwrapped_"
func_pointer_declarations = [
pycparser.c_ast.Decl(
Decl(
name=func_prefix + func_name,
quals=[],
align=[],
Expand All @@ -212,10 +272,10 @@ def definition(self) -> pycparser.c_ast.FuncDef:
init_function_pointers = ParsedFunc(
name="init_function_pointers",
params=(),
return_type=void,
return_type=TypeDecl(declname="a", quals=[], align=None, type=void),
variadic=False,
stmts=[
pycparser.c_ast.Assignment(
Assignment(
op='=',
lvalue=pycparser.c_ast.ID(name=func_prefix + func_name),
rvalue=pycparser.c_ast.FuncCall(
Expand Down Expand Up @@ -243,24 +303,24 @@ def raise_thunk(exception: Exception) -> typing.Callable[..., typing.NoReturn]:


def find_decl(
block: typing.Sequence[pycparser.c_ast.Node],
block: typing.Sequence[Node],
name: str,
comment: typing.Any,
) -> pycparser.c_ast.Decl | None:
) -> Decl | None:
relevant_stmts = [
stmt
for stmt in block
if isinstance(stmt, pycparser.c_ast.Decl) and stmt.name == name
if isinstance(stmt, Decl) and stmt.name == name
]
if not relevant_stmts:
None
return None
elif len(relevant_stmts) > 1:
raise ValueError(f"Multiple definitions of {name}" + " ({})".format(comment) if comment else "")
else:
return relevant_stmts[0]


def wrapper_func_body(func: ParsedFunc) -> typing.Sequence[pycparser.c_ast.Node]:
def wrapper_func_body(func: ParsedFunc) -> typing.Sequence[Node]:
pre_call_stmts = [
pycparser.c_ast.FuncCall(
name=pycparser.c_ast.ID(name="maybe_init_thread"),
Expand All @@ -271,21 +331,16 @@ def wrapper_func_body(func: ParsedFunc) -> typing.Sequence[pycparser.c_ast.Node]

pre_call_action = find_decl(func.stmts, "pre_call", func.name)
if pre_call_action:
if isinstance(pre_call_action.init, pycparser.c_ast.Compound):
if isinstance(pre_call_action.init, Compound):
pre_call_stmts.extend(pre_call_action.init.block_items)
else:
pre_call_stmts.append(pre_call_action.init);

prov_log_is_enabled = pycparser.c_ast.FuncCall(
name=pycparser.c_ast.ID(name="prov_log_is_enabled"),
args=pycparser.c_ast.ExprList(exprs=[]),
)
pre_call_stmts.append(pre_call_action.init)

post_call_action = find_decl(func.stmts, "post_call", func.name)

if post_call_action:
post_call_stmts.extend(
post_call_action.init.block_items,
expect_type(Compound, post_call_action.init).block_items,
)

call_stmts_block = find_decl(func.stmts, "call", func.name)
Expand Down Expand Up @@ -338,10 +393,10 @@ def wrapper_func_body(func: ParsedFunc) -> typing.Sequence[pycparser.c_ast.Node]
else:
call_stmts = [define_var(func.return_type, "ret", call_expr)]
else:
call_stmts = call_stmts_block.init.block_items
call_stmts = expect_type(Compound, call_stmts_block.init).block_items

save_errno = define_var(c_ast_int, "saved_errno", pycparser.c_ast.ID(name="errno"))
restore_errno = pycparser.c_ast.Assignment(
restore_errno = Assignment(
op='=',
lvalue=pycparser.c_ast.ID(name="errno"),
rvalue=pycparser.c_ast.ID(name="saved_errno"),
Expand Down

0 comments on commit ca93ee5

Please sign in to comment.