Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rust: generate test code from schema docstrings #17396

Merged
merged 7 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion misc/codegen/generators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import dbschemegen, qlgen, trapgen, cppgen, rustgen
from . import dbschemegen, trapgen, cppgen, rustgen, rusttestgen, qlgen


def generate(target, opts, renderer):
Expand Down
4 changes: 2 additions & 2 deletions misc/codegen/generators/qlgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def _is_under_qltest_collapsed_hierarchy(cls: schema.Class, lookup: typing.Dict[
_is_in_qltest_collapsed_hierarchy(lookup[b], lookup) for b in cls.bases)


def _should_skip_qltest(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
def should_skip_qltest(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
return "qltest_skip" in cls.pragmas or not (
cls.final or "qltest_collapse_hierarchy" in cls.pragmas) or _is_under_qltest_collapsed_hierarchy(
cls, lookup)
Expand Down Expand Up @@ -413,7 +413,7 @@ def generate(opts, renderer):

if test_out:
for c in data.classes.values():
if _should_skip_qltest(c, data.classes):
if should_skip_qltest(c, data.classes):
continue
test_with = data.classes[c.test_with] if c.test_with else c
test_dir = test_out / test_with.group / test_with.name
Expand Down
32 changes: 18 additions & 14 deletions misc/codegen/generators/rustgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _get_field(cls: schema.Class, p: schema.Property) -> rust.Field:
else:
table_name = inflection.tableize(table_name)
args = dict(
field_name=p.name + ("_" if p.name in rust.keywords else ""),
field_name=rust.avoid_keywords(p.name),
base_type=_get_type(p.type),
is_optional=p.is_optional,
is_repeated=p.is_repeated,
Expand Down Expand Up @@ -86,20 +86,24 @@ def generate(opts, renderer):
processor = Processor(schemaloader.load_file(opts.schema))
out = opts.rust_output
groups = set()
for group, classes in processor.get_classes().items():
group = group or "top"
groups.add(group)
with renderer.manage(generated=out.rglob("*.rs"),
stubs=(),
registry=opts.generated_registry,
force=opts.force) as renderer:
for group, classes in processor.get_classes().items():
group = group or "top"
groups.add(group)
renderer.render(
rust.ClassList(
classes,
opts.schema,
),
out / f"{group}.rs",
)
renderer.render(
rust.ClassList(
classes,
rust.ModuleList(
groups,
opts.schema,
),
out / f"{group}.rs",
out / f"mod.rs",
)
renderer.render(
rust.ModuleList(
groups,
opts.schema,
),
out / f"mod.rs",
)
64 changes: 64 additions & 0 deletions misc/codegen/generators/rusttestgen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import dataclasses
import typing
import inflection

from misc.codegen.loaders import schemaloader
from . import qlgen


@dataclasses.dataclass
class Param:
name: str
type: str
first: bool = False


@dataclasses.dataclass
class Function:
name: str
signature: str


@dataclasses.dataclass
class TestCode:
template: typing.ClassVar[str] = "rust_test_code"

code: str
function: Function | None = None


def generate(opts, renderer):
assert opts.ql_test_output
schema = schemaloader.load_file(opts.schema)
with renderer.manage(generated=opts.ql_test_output.rglob("gen_*.rs"),
stubs=(),
registry=opts.generated_registry,
force=opts.force) as renderer:
for cls in schema.classes.values():
if (qlgen.should_skip_qltest(cls, schema.classes) or
"rust_skip_test_from_doc" in cls.pragmas or
not cls.doc):
continue
code = []
adding_code = False
has_code = False
for line in cls.doc:
match line, adding_code:
case "```", _:
adding_code = not adding_code
has_code = True
case _, False:
code.append(f"// {line}")
case _, True:
code.append(line)
if not has_code:
continue
test_name = inflection.underscore(cls.name)
signature = cls.rust_doc_test_function
fn = signature and Function(f"test_{test_name}", signature)
if fn:
indent = 4 * " "
code = [indent + l for l in code]
test_with = schema.classes[cls.test_with] if cls.test_with else cls
test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{test_name}.rs"
renderer.render(TestCode(code="\n".join(code), function=fn), test)
8 changes: 6 additions & 2 deletions misc/codegen/lib/rust.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@
"try",
}


def avoid_keywords(s: str) -> str:
return s + "_" if s in keywords else s


_field_overrides = [
(re.compile(r"(.*)_"), lambda m: {"field_name": m[1]}),
]
Expand All @@ -82,8 +87,7 @@ class Field:
first: bool = False

def __post_init__(self):
if self.field_name in keywords:
self.field_name += "_"
self.field_name = avoid_keywords(self.field_name)

@property
def type(self) -> str:
Expand Down
1 change: 1 addition & 0 deletions misc/codegen/lib/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class Class:
default_doc_name: Optional[str] = None
hideable: bool = False
test_with: Optional[str] = None
rust_doc_test_function: Optional["FunctionInfo"] = None # TODO: parametrized pragmas

@property
def final(self):
Expand Down
5 changes: 5 additions & 0 deletions misc/codegen/lib/schemadefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def modify(self, prop: _schema.Property):
qltest = _Namespace()
ql = _Namespace()
cpp = _Namespace()
rust = _Namespace()
synth = _SynthModifier()


Expand Down Expand Up @@ -156,6 +157,10 @@ def f(cls: type) -> type:

_Pragma("cpp_skip")

_Pragma("rust_skip_doc_test")

rust.doc_test_signature = lambda signature: _annotate(rust_doc_test_function=signature)


def group(name: str = "") -> _ClassDecorator:
return _annotate(group=name)
Expand Down
1 change: 1 addition & 0 deletions misc/codegen/loaders/schemaloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def _get_class(cls: type) -> schema.Class:
],
doc=schema.split_doc(cls.__doc__),
default_doc_name=cls.__dict__.get("_doc_name"),
rust_doc_test_function=cls.__dict__.get("_rust_doc_test_function")
)


Expand Down
9 changes: 9 additions & 0 deletions misc/codegen/templates/rust_test_code.mustache
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// generated by {{generator}}

{{#function}}
fn {{name}}{{signature}} {
{{/function}}
{{code}}
{{#function}}
}
{{/function}}
Loading
Loading