Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rust: generate test code from schema docstrings #17396

Merged
merged 7 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion misc/codegen/generators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import dbschemegen, qlgen, trapgen, cppgen, rustgen
from . import dbschemegen, trapgen, cppgen, rustgen, rusttestgen, qlgen


def generate(target, opts, renderer):
Expand Down
4 changes: 2 additions & 2 deletions misc/codegen/generators/qlgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def _is_under_qltest_collapsed_hierarchy(cls: schema.Class, lookup: typing.Dict[
_is_in_qltest_collapsed_hierarchy(lookup[b], lookup) for b in cls.bases)


def _should_skip_qltest(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
def should_skip_qltest(cls: schema.Class, lookup: typing.Dict[str, schema.Class]):
return "qltest_skip" in cls.pragmas or not (
cls.final or "qltest_collapse_hierarchy" in cls.pragmas) or _is_under_qltest_collapsed_hierarchy(
cls, lookup)
Expand Down Expand Up @@ -413,7 +413,7 @@ def generate(opts, renderer):

if test_out:
for c in data.classes.values():
if _should_skip_qltest(c, data.classes):
if should_skip_qltest(c, data.classes):
continue
test_with = data.classes[c.test_with] if c.test_with else c
test_dir = test_out / test_with.group / test_with.name
Expand Down
32 changes: 18 additions & 14 deletions misc/codegen/generators/rustgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _get_field(cls: schema.Class, p: schema.Property) -> rust.Field:
else:
table_name = inflection.tableize(table_name)
args = dict(
field_name=p.name + ("_" if p.name in rust.keywords else ""),
field_name=rust.avoid_keywords(p.name),
base_type=_get_type(p.type),
is_optional=p.is_optional,
is_repeated=p.is_repeated,
Expand Down Expand Up @@ -86,20 +86,24 @@ def generate(opts, renderer):
processor = Processor(schemaloader.load_file(opts.schema))
out = opts.rust_output
groups = set()
for group, classes in processor.get_classes().items():
group = group or "top"
groups.add(group)
with renderer.manage(generated=out.rglob("*.rs"),
stubs=(),
registry=opts.generated_registry,
force=opts.force) as renderer:
for group, classes in processor.get_classes().items():
group = group or "top"
groups.add(group)
renderer.render(
rust.ClassList(
classes,
opts.schema,
),
out / f"{group}.rs",
)
renderer.render(
rust.ClassList(
classes,
rust.ModuleList(
groups,
opts.schema,
),
out / f"{group}.rs",
out / f"mod.rs",
)
renderer.render(
rust.ModuleList(
groups,
opts.schema,
),
out / f"mod.rs",
)
64 changes: 64 additions & 0 deletions misc/codegen/generators/rusttestgen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import dataclasses
import typing
import inflection

from misc.codegen.loaders import schemaloader
from . import qlgen


@dataclasses.dataclass
class Param:
name: str
type: str
first: bool = False


@dataclasses.dataclass
class Function:
name: str
signature: str


@dataclasses.dataclass
class TestCode:
template: typing.ClassVar[str] = "rust_test_code"

code: str
function: Function | None = None


def generate(opts, renderer):
assert opts.ql_test_output
schema = schemaloader.load_file(opts.schema)
with renderer.manage(generated=opts.ql_test_output.rglob("gen_*.rs"),
stubs=(),
registry=opts.generated_registry,
force=opts.force) as renderer:
for cls in schema.classes.values():
if (qlgen.should_skip_qltest(cls, schema.classes) or
"rust_skip_test_from_doc" in cls.pragmas or
not cls.doc):
continue
code = []
adding_code = False
has_code = False
for line in cls.doc:
match line, adding_code:
case "```", _:
adding_code = not adding_code
has_code = True
case _, False:
code.append(f"// {line}")
case _, True:
code.append(line)
if not has_code:
continue
test_name = inflection.underscore(cls.name)
signature = cls.rust_doc_test_function
fn = signature and Function(f"test_{test_name}", signature)
if fn:
indent = 4 * " "
code = [indent + l for l in code]
test_with = schema.classes[cls.test_with] if cls.test_with else cls
test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{test_name}.rs"
renderer.render(TestCode(code="\n".join(code), function=fn), test)
8 changes: 6 additions & 2 deletions misc/codegen/lib/rust.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@
"try",
}


def avoid_keywords(s: str) -> str:
return s + "_" if s in keywords else s


_field_overrides = [
(re.compile(r"(.*)_"), lambda m: {"field_name": m[1]}),
]
Expand All @@ -82,8 +87,7 @@ class Field:
first: bool = False

def __post_init__(self):
if self.field_name in keywords:
self.field_name += "_"
self.field_name = avoid_keywords(self.field_name)

@property
def type(self) -> str:
Expand Down
1 change: 1 addition & 0 deletions misc/codegen/lib/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class Class:
default_doc_name: Optional[str] = None
hideable: bool = False
test_with: Optional[str] = None
rust_doc_test_function: Optional["FunctionInfo"] = None # TODO: parametrized pragmas

@property
def final(self):
Expand Down
5 changes: 5 additions & 0 deletions misc/codegen/lib/schemadefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def modify(self, prop: _schema.Property):
qltest = _Namespace()
ql = _Namespace()
cpp = _Namespace()
rust = _Namespace()
synth = _SynthModifier()


Expand Down Expand Up @@ -156,6 +157,10 @@ def f(cls: type) -> type:

_Pragma("cpp_skip")

_Pragma("rust_skip_doc_test")

rust.doc_test_signature = lambda signature: _annotate(rust_doc_test_function=signature)


def group(name: str = "") -> _ClassDecorator:
return _annotate(group=name)
Expand Down
1 change: 1 addition & 0 deletions misc/codegen/loaders/schemaloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def _get_class(cls: type) -> schema.Class:
],
doc=schema.split_doc(cls.__doc__),
default_doc_name=cls.__dict__.get("_doc_name"),
rust_doc_test_function=cls.__dict__.get("_rust_doc_test_function")
)


Expand Down
9 changes: 9 additions & 0 deletions misc/codegen/templates/rust_test_code.mustache
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// generated by {{generator}}

{{#function}}
fn {{name}}{{signature}} {
{{/function}}
{{code}}
{{#function}}
}
{{/function}}
12 changes: 5 additions & 7 deletions rust/.generated.list

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 3 additions & 5 deletions rust/.gitattributes

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion rust/codegen.conf
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# configuration file for Swift code generation default options
--generate=dbscheme,ql,rust
--generate=dbscheme,rusttest,ql,rust
--dbscheme=ql/lib/rust.dbscheme
--ql-output=ql/lib/codeql/rust/generated
--ql-stub-output=ql/lib/codeql/rust/elements
Expand Down
13 changes: 13 additions & 0 deletions rust/ql/lib/codeql/rust/elements/Function.qll
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,19 @@

private import codeql.rust.generated.Function

// the following QLdoc is generated: if you need to edit it, do it in the schema file
/**
* A function declaration. For example
* ```
* fn foo(x: u32) -> u64 { (x + 1).into() }
* ```
* A function declaration within a trait might not have a body:
* ```
* trait Trait {
* fn bar();
* }
* ```
*/
class Function extends Generated::Function {
override string toString() { result = this.getName() }
}
10 changes: 10 additions & 0 deletions rust/ql/lib/codeql/rust/generated/Function.qll
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ import codeql.rust.elements.Declaration
*/
module Generated {
/**
* A function declaration. For example
* ```
* fn foo(x: u32) -> u64 { (x + 1).into() }
* ```
* A function declaration within a trait might not have a body:
* ```
* trait Trait {
* fn bar();
* }
* ```
* INTERNAL: Do not reference the `Generated::Function` class directly.
* Use the subclass `Function`, where the following predicates are available.
*/
Expand Down
10 changes: 10 additions & 0 deletions rust/ql/lib/codeql/rust/generated/Raw.qll
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ module Raw {

/**
* INTERNAL: Do not use.
* A function declaration. For example
* ```
* fn foo(x: u32) -> u64 { (x + 1).into() }
* ```
* A function declaration within a trait might not have a body:
* ```
* trait Trait {
* fn bar();
* }
* ```
*/
class Function extends @function, Declaration {
override string toString() { result = "Function" }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
| test.rs:0:0:0:11 | foo | getName: | foo |
| test.rs:1:0:1:11 | bar | getName: | bar |
| test.rs:2:0:2:11 | baz | getName: | baz |
| gen_function.rs:2:0:3:40 | foo | getName: | foo |
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// generated by codegen

// A function declaration. For example
fn foo(x: u32) -> u64 { (x + 1).into() }
// A function declaration within a trait might not have a body:
trait Trait {
fn bar();
}
3 changes: 0 additions & 3 deletions rust/ql/test/extractor-tests/generated/Function/test.rs

This file was deleted.

12 changes: 12 additions & 0 deletions rust/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,16 @@ class Module(Declaration):
declarations: list[Declaration] | child

class Function(Declaration):
"""
A function declaration. For example
```
fn foo(x: u32) -> u64 { (x + 1).into() }
```
A function declaration within a trait might not have a body:
```
trait Trait {
fn bar();
}
```
"""
name: string
Loading