Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RSDK-8965: add subclasses to module generation #4491

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-module-generation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ jobs:
{ subtype: "arm", type: "component" },
{ subtype: "audio_input", type: "component" },
{ subtype: "base", type: "component" },
{ subtype: "board", type: "component" },
{ subtype: "camera", type: "component" },
{ subtype: "encoder", type: "component" },
{ subtype: "gantry", type: "component" },
Expand Down Expand Up @@ -50,7 +51,6 @@ jobs:
- name: Run module
run: |
cd my-module
chmod +x run.sh
./run.sh /tmp/viam.sock &
PID=$!
sleep 5
Expand Down
1 change: 1 addition & 0 deletions cli/module_generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ func promptUser() (*common.ModuleInputs, error) {
huh.NewOption("Arm Component", "arm component"),
huh.NewOption("Audio Input Component", "audio_input component"),
huh.NewOption("Base Component", "base component"),
huh.NewOption("Board Component", "board component"),
huh.NewOption("Camera Component", "camera component"),
huh.NewOption("Encoder Component", "encoder component"),
huh.NewOption("Gantry Component", "gantry component"),
Expand Down
58 changes: 52 additions & 6 deletions cli/module_generate/scripts/generate_stubs.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,22 @@ func setGoModuleTemplate(clientCode string, module common.ModuleInputs) (*common
}
var functions []string
ast.Inspect(node, func(n ast.Node) bool {
if typeSpec, ok := n.(*ast.TypeSpec); ok {
if _, ok := typeSpec.Type.(*ast.StructType); ok {
if strings.Contains(typeSpec.Name.Name, "Client") {
jckras marked this conversation as resolved.
Show resolved Hide resolved
functions = append(functions, formatStruct(typeSpec, module.ModuleCamel+module.ModelPascal))
}
}
}
if funcDecl, ok := n.(*ast.FuncDecl); ok {
name, args, returns := parseFunctionSignature(module.ResourceSubtype, module.ResourceSubtypePascal, funcDecl)
name, receiver, args, returns := parseFunctionSignature(
module.ResourceSubtype,
module.ResourceSubtypePascal,
module.ModuleCamel+module.ModelPascal,
funcDecl,
)
if name != "" {
functions = append(functions, formatEmptyFunction(module.ModuleCamel+module.ModelPascal, name, args, returns))
functions = append(functions, formatEmptyFunction(receiver, name, args, returns))
}
}
return true
Expand Down Expand Up @@ -130,8 +142,22 @@ func handleMapType(str, resourceSubtype string) string {
return fmt.Sprintf("map[%s]%s", keyType, valueType)
}

func formatStruct(typeSpec *ast.TypeSpec, modelType string) string {
var buf bytes.Buffer
err := printer.Fprint(&buf, token.NewFileSet(), typeSpec)
if err != nil {
return fmt.Sprintf("Error formatting type: %v", err)
}
return "type " + strings.ReplaceAll(buf.String(), "*client", "*"+modelType) + "\n\n"
}

// parseFunctionSignature parses function declarations into the function name, the arguments, and the return types.
func parseFunctionSignature(resourceSubtype, resourceSubtypePascal string, funcDecl *ast.FuncDecl) (name, args string, returns []string) {
func parseFunctionSignature(
resourceSubtype,
resourceSubtypePascal string,
modelType string,
funcDecl *ast.FuncDecl,
) (name, receiver, args string, returns []string) {
if funcDecl == nil {
return
}
Expand All @@ -145,15 +171,35 @@ func parseFunctionSignature(resourceSubtype, resourceSubtypePascal string, funcD
return
}

// Receiver
receiver = modelType
if funcDecl.Recv != nil && len(funcDecl.Recv.List) > 0 {
field := funcDecl.Recv.List[0]
if starExpr, ok := field.Type.(*ast.StarExpr); ok {
if ident, ok := starExpr.X.(*ast.Ident); ok {
if ident.Name != "client" {
receiver = ident.Name
}
}
}
}

// Parameters
var params []string
if funcDecl.Type.Params != nil {
for _, param := range funcDecl.Type.Params.List {
paramType := formatType(param.Type)
if unicode.IsUpper(rune(paramType[0])) {

// Check if `paramType` is a type that is capitalized.
// If so, attribute the type to <resourceSubtype>.
switch {
case unicode.IsUpper(rune(paramType[0])):
paramType = fmt.Sprintf("%s.%s", resourceSubtype, paramType)
} else if strings.HasPrefix(paramType, "[]") && unicode.IsUpper(rune(paramType[2])) {
// IF `paramType` has a prefix, check if type is capitalized after prefix.
case strings.HasPrefix(paramType, "[]") && unicode.IsUpper(rune(paramType[2])):
paramType = fmt.Sprintf("[]%s.%s", resourceSubtype, paramType[2:])
case strings.HasPrefix(paramType, "chan ") && unicode.IsUpper(rune(paramType[5])):
paramType = fmt.Sprintf("chan %s.%s", resourceSubtype, paramType[5:])
stuqdog marked this conversation as resolved.
Show resolved Hide resolved
}

for _, name := range param.Names {
Expand Down Expand Up @@ -199,7 +245,7 @@ func parseFunctionSignature(resourceSubtype, resourceSubtypePascal string, funcD
}
}

return funcName, strings.Join(params, ", "), returns
return funcName, receiver, strings.Join(params, ", "), returns
}

// formatEmptyFunction outputs the new function that removes the function body, adds the panic unimplemented statement,
Expand Down
152 changes: 104 additions & 48 deletions cli/module_generate/scripts/generate_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,92 @@
import subprocess
import sys
from importlib import import_module
from typing import List, Set


def return_attribute(resource_name: str, attr: str) -> ast.Attribute:
def return_attribute(value: str, attr: str) -> ast.Attribute:
jckras marked this conversation as resolved.
Show resolved Hide resolved
return ast.Attribute(
value=ast.Name(id=resource_name, ctx=ast.Load()),
value=ast.Name(id=value, ctx=ast.Load()),
attr=attr,
ctx=ast.Load())


def update_annotation(
resource_name: str,
annotation: ast.Name | ast.Subscript,
nodes: Set[str],
parent: str
) -> ast.Attribute | ast.Subscript:
if isinstance(annotation, ast.Name) and annotation.id in nodes:
value = parent if parent else resource_name
return return_attribute(value, annotation.id)
elif isinstance(annotation, ast.Subscript):
annotation.slice = update_annotation(
resource_name,
annotation.slice,
nodes,
parent)
return annotation


def replace_async_func(
resource_name: str,
func: ast.AsyncFunctionDef,
nodes: Set[str],
parent: str = ""
) -> None:
for arg in func.args.args:
arg.annotation = update_annotation(
resource_name,
arg.annotation,
nodes,
parent)
func.body = [
ast.Raise(
exc=ast.Call(func=ast.Name(id='NotImplementedError',
ctx=ast.Load()),
args=[],
keywords=[]),
cause=None)
]
func.decorator_list = []
if isinstance(func.returns, (ast.Name, ast.Subscript)):
func.returns = update_annotation(
resource_name, func.returns, nodes, parent
)


def return_subclass(
resource_name: str, stmt: ast.ClassDef, parent: str = ""
) -> List[str]:
def parse_subclass(resource_name: str, stmt: ast.ClassDef, parent: str):
nodes = set()
nodes_to_remove = []
parent = parent if parent else resource_name
stmt.bases = [ast.Name(id=f"{parent}.{stmt.name}", ctx=ast.Load())]
for cstmt in stmt.body:
if isinstance(cstmt, ast.Expr) or (
isinstance(cstmt, ast.FunctionDef) and cstmt.name == "__init__"
):
nodes_to_remove.append(cstmt)
elif isinstance(cstmt, ast.AnnAssign):
nodes.add(cstmt.target.id)
nodes_to_remove.append(cstmt)
elif isinstance(cstmt, ast.ClassDef):
parse_subclass(resource_name, cstmt, stmt.bases[0].id)
jckras marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(cstmt, ast.AsyncFunctionDef):
replace_async_func(resource_name, cstmt, nodes, stmt.bases[0].id)
for node in nodes_to_remove:
stmt.body.remove(node)
if stmt.body == []:
stmt.body = [ast.Pass()]

parse_subclass(resource_name, stmt, parent)
return '\n'.join(
[' ' + line for line in ast.unparse(stmt).splitlines()]
)


def main(
resource_type: str,
resource_subtype: str,
Expand All @@ -22,48 +99,37 @@ def main(
import isort
from slugify import slugify

module_name = f"viam.{resource_type}s.{resource_subtype}.{resource_subtype}"
module_name = (
f"viam.{resource_type}s.{resource_subtype}.{resource_subtype}"
)
module = import_module(module_name)
if resource_subtype == "input":
resource_name = "Controller"
elif resource_subtype == "slam":
resource_name = "SLAM"
elif resource_subtype == "mlmodel":
resource_name = "MLModel"
else:
resource_name = "".join(word.capitalize() for word in resource_subtype.split("_"))

imports = []
resource_name = {
"input": "Controller", "slam": "SLAM", "mlmodel": "MLModel"
}.get(resource_subtype, "".join(word.capitalize()
for word in resource_subtype.split("_")))

imports, subclasses, abstract_methods = [], [], []
nodes = set()
modules_to_ignore = [
"abc",
"component_base",
"service_base",
"viam.resource.types",
]
abstract_methods = []
with open(module.__file__, "r") as f:
def update_annotation(annotation):
if isinstance(annotation, ast.Name) and annotation.id in nodes:
return return_attribute(resource_name, annotation.id)
elif isinstance(annotation, ast.Subscript):
annotation.slice = update_annotation(annotation.slice)
return annotation
return annotation

tree = ast.parse(f.read())
nodes = []
for stmt in tree.body:
if isinstance(stmt, ast.Import):
for imp in stmt.names:
if imp.name in modules_to_ignore:
continue
if imp.asname:
imports.append(f"import {imp.name} as {imp.asname}")
else:
imports.append(f"import {imp.name}")
elif isinstance(stmt, ast.ImportFrom):
if stmt.module in modules_to_ignore or stmt.module is None:
continue
imports.append(f"import {imp.name} as {imp.asname}"
if imp.asname else f"import {imp.name}")
elif (
isinstance(stmt, ast.ImportFrom)
and stmt.module
and stmt.module not in modules_to_ignore
):
i_strings = ", ".join(
[
(
Expand All @@ -79,26 +145,14 @@ def update_annotation(annotation):
elif isinstance(stmt, ast.ClassDef) and stmt.name == resource_name:
for cstmt in stmt.body:
if isinstance(cstmt, ast.ClassDef):
nodes.append(cstmt.name)
subclasses.append(return_subclass(resource_name, cstmt))
elif isinstance(cstmt, ast.AnnAssign):
nodes.append(cstmt.target.id)
nodes.add(cstmt.target.id)
elif isinstance(cstmt, ast.AsyncFunctionDef):
for arg in cstmt.args.args:
arg.annotation = update_annotation(arg.annotation)

cstmt.body = [
ast.Raise(
exc=ast.Call(
func=ast.Name(id='NotImplementedError', ctx=ast.Load()),
args=[],
keywords=[]),
cause=None,
)
]
cstmt.decorator_list = []
if isinstance(cstmt.returns, ast.Name) and cstmt.returns.id in nodes:
cstmt.returns = return_attribute(resource_name, cstmt.returns.id)
indented_code = '\n'.join([' ' + line for line in ast.unparse(cstmt).splitlines()])
replace_async_func(resource_name, cstmt, nodes)
indented_code = '\n'.join(
[' ' + line for line in ast.unparse(cstmt).splitlines()]
)
jckras marked this conversation as resolved.
Show resolved Hide resolved
abstract_methods.append(indented_code)

model_name_pascal = "".join(
Expand Down Expand Up @@ -158,6 +212,7 @@ def reconfigure(self, config: ComponentConfig, dependencies: Mapping[ResourceNam
return super().reconfigure(config, dependencies)

{8}
{9}


if __name__ == '__main__':
Expand All @@ -172,6 +227,7 @@ def reconfigure(self, config: ComponentConfig, dependencies: Mapping[ResourceNam
namespace,
mod_name,
model_name,
'\n\n'.join([subclass for subclass in subclasses]),
'\n\n'.join([f'{method}' for method in abstract_methods]),
)
f_name = os.path.join(mod_name, "src", "main.py")
Expand Down
Loading