diff --git a/Makefile b/Makefile index 81357f4..242ed38 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ # Helper script for decompiling some files for testing fernflower = fabric-fernflower-1.4.1+local.jar -quiltflower = quiltflower-1.7.0+local.jar +quiltflower = quiltflower-1.8.1+local.jar namedjar = 1.16.5-named.jar procyon = `curl --silent https://api.github.com/repos/mstrobel/procyon/releases/latest | jq -r .assets[0].name` diff --git a/astutil/type_parsing.go b/astutil/type_parsing.go new file mode 100644 index 0000000..e49460b --- /dev/null +++ b/astutil/type_parsing.go @@ -0,0 +1,61 @@ +package astutil + +import ( + "fmt" + "go/ast" + + sitter "github.com/smacker/go-tree-sitter" +) + +func ParseType(node *sitter.Node, source []byte) ast.Expr { + switch node.Type() { + case "integral_type": + switch node.Child(0).Type() { + case "int": + return &ast.Ident{Name: "int32"} + case "short": + return &ast.Ident{Name: "int16"} + case "long": + return &ast.Ident{Name: "int64"} + case "char": + return &ast.Ident{Name: "rune"} + case "byte": + return &ast.Ident{Name: node.Content(source)} + } + + panic(fmt.Errorf("Unknown integral type: %v", node.Child(0).Type())) + case "floating_point_type": // Can be either `float` or `double` + switch node.Child(0).Type() { + case "float": + return &ast.Ident{Name: "float32"} + case "double": + return &ast.Ident{Name: "float64"} + } + + panic(fmt.Errorf("Unknown float type: %v", node.Child(0).Type())) + case "void_type": + return &ast.Ident{} + case "boolean_type": + return &ast.Ident{Name: "bool"} + case "generic_type": + // A generic type is any type that is of the form GenericType + return &ast.Ident{Name: node.NamedChild(0).Content(source)} + case "array_type": + return &ast.ArrayType{Elt: ParseType(node.NamedChild(0), source)} + case "type_identifier": // Any reference type + switch node.Content(source) { + // Special case for strings, because in Go, these are primitive types + case "String": + return &ast.Ident{Name: 
"string"} + } + + return &ast.StarExpr{ + X: &ast.Ident{Name: node.Content(source)}, + } + case "scoped_type_identifier": + // This contains a reference to the type of a nested class + // Ex: LinkedList.Node + return &ast.StarExpr{X: &ast.Ident{Name: node.Content(source)}} + } + panic("Unknown type to convert: " + node.Type()) +} diff --git a/declaration.go b/declaration.go index ab2b10a..f08e963 100644 --- a/declaration.go +++ b/declaration.go @@ -3,8 +3,10 @@ package main import ( "go/ast" "go/token" - "strings" + "github.com/NickyBoy89/java2go/nodeutil" + "github.com/NickyBoy89/java2go/symbol" + log "github.com/sirupsen/logrus" sitter "github.com/smacker/go-tree-sitter" ) @@ -17,47 +19,29 @@ func ParseDecls(node *sitter.Node, source []byte, ctx Ctx) []ast.Decl { //"superclass" //"interfaces" - // All the declarations for the class + // The declarations and fields for the class declarations := []ast.Decl{} + fields := &ast.FieldList{} // Global variables globalVariables := &ast.GenDecl{Tok: token.VAR} - // Other declarations - fields := &ast.FieldList{} - - var public bool - - if node.NamedChild(0).Type() == "modifiers" { - for _, modifier := range UnnamedChildren(node.NamedChild(0)) { - if modifier.Type() == "public" { - public = true - } - } - } - - if public { - ctx.className = ToPublic(node.ChildByFieldName("name").Content(source)) - } else { - ctx.className = ToPrivate(node.ChildByFieldName("name").Content(source)) - } + ctx.className = ctx.currentFile.FindClass(node.ChildByFieldName("name").Content(source)).Name // First, look through the class's body for field declarations - for _, child := range Children(node.ChildByFieldName("body")) { + for _, child := range nodeutil.NamedChildrenOf(node.ChildByFieldName("body")) { if child.Type() == "field_declaration" { - var publicField, staticField bool + var staticField bool comments := []*ast.Comment{} // Handle any modifiers that the field might have if child.NamedChild(0).Type() == "modifiers" { - for _, 
modifier := range UnnamedChildren(child.NamedChild(0)) { + for _, modifier := range nodeutil.UnnamedChildrenOf(child.NamedChild(0)) { switch modifier.Type() { case "static": staticField = true - case "public": - publicField = true case "marker_annotation", "annotation": comments = append(comments, &ast.Comment{Text: "//" + modifier.Content(source)}) if _, in := excludedAnnotations[modifier.Content(source)]; in { @@ -68,91 +52,78 @@ func ParseDecls(node *sitter.Node, source []byte, ctx Ctx) []ast.Decl { } } - // Parse the field declaration - // The field can either be a `Field`, or a `ValueSpec` if it was assigned to a value - field := ParseNode(child, source, ctx) - - if valueField, hasValue := field.(*ast.ValueSpec); hasValue { - if len(comments) > 0 { - valueField.Doc = &ast.CommentGroup{List: comments} - } - - if staticField { - // Add the name of the current class to scope the variable to the current class - valueField.Names[0].Name = ctx.className + valueField.Names[0].Name + // TODO: If a field is initialized to a value, that value is discarded - if publicField { - valueField.Names[0] = CapitalizeIdent(valueField.Names[0]) - } else { - valueField.Names[0] = LowercaseIdent(valueField.Names[0]) - } + field := &ast.Field{} + if len(comments) > 0 { + field.Doc = &ast.CommentGroup{List: comments} + } - globalVariables.Specs = append(globalVariables.Specs, valueField) - } else { - // TODO: If a variable is not static and it is initialized to - // a value, the value is thrown away - fields.List = append(fields.List, &ast.Field{Names: valueField.Names, Type: valueField.Type}) - } - } else { - if len(comments) > 0 { - field.(*ast.Field).Doc = &ast.CommentGroup{List: comments} - } + fieldName := child.ChildByFieldName("declarator").ChildByFieldName("name").Content(source) - if staticField { - // Add the name of the current class to scope the variable to the current class - field.(*ast.Field).Names[0].Name = ctx.className + field.(*ast.Field).Names[0].Name + fieldDef := 
ctx.currentClass.FindField().ByOriginalName(fieldName)[0] - if publicField { - field.(*ast.Field).Names[0] = CapitalizeIdent(field.(*ast.Field).Names[0]) - } else { - field.(*ast.Field).Names[0] = LowercaseIdent(field.(*ast.Field).Names[0]) - } + field.Names, field.Type = []*ast.Ident{&ast.Ident{Name: fieldDef.Name}}, &ast.Ident{Name: fieldDef.Type} - globalVariables.Specs = append(globalVariables.Specs, &ast.ValueSpec{Names: field.(*ast.Field).Names, Type: field.(*ast.Field).Type}) - } else { - fields.List = append(fields.List, field.(*ast.Field)) - } + if staticField { + globalVariables.Specs = append(globalVariables.Specs, &ast.ValueSpec{Names: field.Names, Type: field.Type}) + } else { + fields.List = append(fields.List, field) } } } - // Add everything into the declarations - + // Add the global variables if len(globalVariables.Specs) > 0 { declarations = append(declarations, globalVariables) } + // Add any type parameters defined in the class if node.ChildByFieldName("type_parameters") != nil { declarations = append(declarations, ParseDecls(node.ChildByFieldName("type_parameters"), source, ctx)...) } + // Add the struct for the class declarations = append(declarations, GenStruct(ctx.className, fields)) + // Add all the declarations that appear in the class declarations = append(declarations, ParseDecls(node.ChildByFieldName("body"), source, ctx)...) 
return declarations - case "class_body": + case "class_body": // The body of the currently parsed class decls := []ast.Decl{} - var child *sitter.Node - for i := 0; i < int(node.NamedChildCount()); i++ { - child = node.NamedChild(i) + + // To switch to parsing the subclasses of a class, since we assume that + // all the class's subclass definitions are in-order, if we find some number + // of subclasses in a class, we can refer to them by index + var subclassIndex int + + for _, child := range nodeutil.NamedChildrenOf(node) { switch child.Type() { // Skip fields and comments case "field_declaration", "comment": case "constructor_declaration", "method_declaration", "static_initializer": d := ParseDecl(child, source, ctx) - if _, bad := d.(*ast.BadDecl); !bad { + // If the declaration is bad, skip it + _, bad := d.(*ast.BadDecl) + if !bad { decls = append(decls, d) } + + // Subclasses case "class_declaration", "interface_declaration", "enum_declaration": - decls = append(decls, ParseDecls(child, source, ctx)...) + newCtx := ctx.Clone() + newCtx.currentClass = ctx.currentClass.Subclasses[subclassIndex] + subclassIndex++ + decls = append(decls, ParseDecls(child, source, newCtx)...) 
} } + return decls case "interface_body": methods := &ast.FieldList{} - for _, c := range Children(node) { + for _, c := range nodeutil.NamedChildrenOf(node) { if c.Type() == "method_declaration" { parsedMethod := ParseNode(c, source, ctx).(*ast.Field) // If the method was ignored with an annotation, it will return a blank @@ -165,44 +136,23 @@ func ParseDecls(node *sitter.Node, source []byte, ctx Ctx) []ast.Decl { return []ast.Decl{GenInterface(ctx.className, methods)} case "interface_declaration": - decls := []ast.Decl{} - - for _, c := range Children(node) { - switch c.Type() { - case "modifiers": - case "identifier": - ctx.className = c.Content(source) - case "interface_body": - decls = ParseDecls(c, source, ctx) - } - } + ctx.className = ctx.currentFile.FindClass(node.ChildByFieldName("name").Content(source)).Name - return decls + return ParseDecls(node.ChildByFieldName("body"), source, ctx) case "enum_declaration": // An enum is treated as both a struct, and a list of values that define // the states that the enum can be in - //modifiers := ParseNode(node.NamedChild(0), source, ctx) - - ctx.className = node.NamedChild(1).Content(source) - - for _, item := range Children(node.NamedChild(2)) { - switch item.Type() { - case "enum_body_declarations": - for _, bodyDecl := range Children(item) { - _ = bodyDecl - } - } - } + ctx.className = ctx.currentFile.FindClass(node.ChildByFieldName("name").Content(source)).Name - // TODO: Fix this to handle an enum correctly - //decls := []ast.Decl{GenStruct(ctx.className, fields)} + // TODO: Handle an enum correctly + //return ParseDecls(node.ChildByFieldName("body"), source, ctx) return []ast.Decl{} case "type_parameters": var declarations []ast.Decl // A list of generic type parameters - for _, param := range Children(node) { + for _, param := range nodeutil.NamedChildrenOf(node) { switch param.Type() { case "type_parameter": declarations = append(declarations, GenTypeInterface(param.NamedChild(0).Content(source), 
[]string{"any"})) @@ -219,137 +169,160 @@ func ParseDecls(node *sitter.Node, source []byte, ctx Ctx) []ast.Decl { func ParseDecl(node *sitter.Node, source []byte, ctx Ctx) ast.Decl { switch node.Type() { case "constructor_declaration": - var body *ast.BlockStmt - var name *ast.Ident - var params *ast.FieldList - - for _, c := range Children(node) { - switch c.Type() { - case "identifier": - name = ParseExpr(c, source, ctx).(*ast.Ident) - case "formal_parameters": - params = ParseNode(c, source, ctx).(*ast.FieldList) - case "constructor_body": - body = ParseStmt(c, source, ctx).(*ast.BlockStmt) + paramNode := node.ChildByFieldName("parameters") + + constructorName := node.ChildByFieldName("name").Content(source) + + comparison := func(d *symbol.Definition) bool { + // The names must match + if constructorName != d.OriginalName { + return false + } + + // Size of parameters must match + if int(paramNode.NamedChildCount()) != len(d.Parameters) { + return false + } + + // Go through the types and check to see if they differ + for index, param := range nodeutil.NamedChildrenOf(paramNode) { + var paramType string + if param.Type() == "spread_parameter" { + paramType = param.NamedChild(0).Content(source) + } else { + paramType = param.ChildByFieldName("type").Content(source) + } + if paramType != d.Parameters[index].OriginalType { + return false + } } + + return true } - // Create the object to construct in the constructor - body.List = append([]ast.Stmt{&ast.AssignStmt{ - Lhs: []ast.Expr{&ast.Ident{Name: ShortName(ctx.className)}}, - Tok: token.DEFINE, - Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.Ident{Name: "new"}, Args: []ast.Expr{&ast.Ident{Name: ctx.className}}}}, - }}, body.List...) 
- // Return the created object + // Search through the current class for the constructor, which is simply labeled as a method + ctx.localScope = ctx.currentClass.FindMethod().By(comparison)[0] + + body := ParseStmt(node.ChildByFieldName("body"), source, ctx).(*ast.BlockStmt) + + body.List = append([]ast.Stmt{ + &ast.AssignStmt{ + Lhs: []ast.Expr{&ast.Ident{Name: ShortName(ctx.className)}}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.CallExpr{Fun: &ast.Ident{Name: "new"}, Args: []ast.Expr{&ast.Ident{Name: ctx.className}}}}, + }, + }, body.List...) + body.List = append(body.List, &ast.ReturnStmt{Results: []ast.Expr{&ast.Ident{Name: ShortName(ctx.className)}}}) return &ast.FuncDecl{ - Name: &ast.Ident{Name: "New" + name.Name}, + Name: &ast.Ident{Name: ctx.localScope.Name}, Type: &ast.FuncType{ - Params: params, + Params: ParseNode(node.ChildByFieldName("parameters"), source, ctx).(*ast.FieldList), Results: &ast.FieldList{List: []*ast.Field{&ast.Field{ - Type: &ast.StarExpr{ - X: name, - }, + Type: &ast.Ident{Name: ctx.localScope.Type}, }}}, }, Body: body, } case "method_declaration": - var public, static bool - - // The return type comes as the second node, after the modifiers - // however, if the method is generic, this gets pushed down one - returnTypeIndex := 1 - if node.NamedChild(1).Type() == "type_parameters" { - returnTypeIndex++ - } - - returnType := ParseExpr(node.NamedChild(returnTypeIndex), source, ctx) - - var methodName *ast.Ident - - var params *ast.FieldList + var static bool // Store the annotations as comments on the method comments := []*ast.Comment{} - for _, c := range Children(node) { - switch c.Type() { - case "modifiers": - for _, mod := range UnnamedChildren(c) { - switch mod.Type() { - case "public": - public = true - case "static": - static = true - case "abstract": - // TODO: Handle abstract methods correctly + if node.NamedChild(0).Type() == "modifiers" { + for _, modifier := range nodeutil.UnnamedChildrenOf(node.NamedChild(0)) { + switch 
modifier.Type() { + case "static": + static = true + case "abstract": + log.Warn("Unhandled abstract class") + // TODO: Handle abstract methods correctly + return &ast.BadDecl{} + case "marker_annotation", "annotation": + comments = append(comments, &ast.Comment{Text: "//" + modifier.Content(source)}) + // If the annotation was on the list of ignored annotations, don't + // parse the method + if _, in := excludedAnnotations[modifier.Content(source)]; in { return &ast.BadDecl{} - case "marker_annotation", "annotation": - comments = append(comments, &ast.Comment{Text: "//" + mod.Content(source)}) - // If the annotation was on the list of ignored annotations, don't - // parse the method - if _, in := excludedAnnotations[mod.Content(source)]; in { - return &ast.BadDecl{} - } } } - case "type_parameters": // For generic types - case "formal_parameters": - params = ParseNode(c, source, ctx).(*ast.FieldList) - case "identifier": - if returnType == nil { - continue - } - // The next two identifiers determine the return type and name of the method - if public { - methodName = CapitalizeIdent(ParseExpr(c, source, ctx).(*ast.Ident)) - } else { - methodName = LowercaseIdent(ParseExpr(c, source, ctx).(*ast.Ident)) - } } } - var methodRecv *ast.FieldList + var receiver *ast.FieldList - // If the method is not static, define it as a struct's method + // If a function is non-static, it has a method receiver if !static { - methodRecv = &ast.FieldList{List: []*ast.Field{ - &ast.Field{ - Names: []*ast.Ident{&ast.Ident{Name: ShortName(ctx.className)}}, - Type: &ast.StarExpr{X: &ast.Ident{Name: ctx.className}}, + receiver = &ast.FieldList{ + List: []*ast.Field{ + &ast.Field{ + Names: []*ast.Ident{&ast.Ident{Name: ShortName(ctx.className)}}, + Type: &ast.StarExpr{X: &ast.Ident{Name: ctx.className}}, + }, }, - }} + } } - // If the methodName is nil, then the printer will panic - if methodName == nil { - panic("Method's name is nil") + methodName := 
ParseExpr(node.ChildByFieldName("name"), source, ctx).(*ast.Ident) + + methodParameters := node.ChildByFieldName("parameters") + + // Find the declaration for the method that we are defining + + // Find a method that is more or less exactly the same + comparison := func(d *symbol.Definition) bool { + // Throw out any methods that aren't named the same + if d.OriginalName != methodName.Name { + return false + } + + // Now, even though the method might have the same name, it could be overloaded, + // so we have to check the parameters as well + + // Number of parameters are not the same, invalid + if len(d.Parameters) != int(methodParameters.NamedChildCount()) { + return false + } + + // Go through the types and check to see if they differ + for index, param := range nodeutil.NamedChildrenOf(methodParameters) { + var paramType string + if param.Type() == "spread_parameter" { + paramType = param.NamedChild(0).Content(source) + } else { + paramType = param.ChildByFieldName("type").Content(source) + } + if d.Parameters[index].OriginalType != paramType { + return false + } + } + + // We found the correct method + return true } - method := &ast.FuncDecl{ - Doc: &ast.CommentGroup{List: comments}, - Name: methodName, - Recv: methodRecv, - Type: &ast.FuncType{ - Params: params, - Results: &ast.FieldList{List: []*ast.Field{ - &ast.Field{Type: returnType}, - }}, - }, - Body: ParseStmt(node.NamedChild(int(node.NamedChildCount()-1)), source, ctx).(*ast.BlockStmt), + methodDefinition := ctx.currentClass.FindMethod().By(comparison) + + // No definition was found + if len(methodDefinition) == 0 { + log.WithFields(log.Fields{ + "methodName": methodName.Name, + }).Panic("No matching definition found for method") } - // Special case for the main method, since this should always be lowercase, - // and per java rules, have an array of args defined with it - if strings.ToLower(methodName.Name) == "main" { - methodName.Name = "main" - // Remove all of its parameters - method.Type.Params = 
nil - // Add a new variable for the args - // args := os.Args - method.Body.List = append([]ast.Stmt{ + ctx.localScope = methodDefinition[0] + + body := ParseStmt(node.ChildByFieldName("body"), source, ctx).(*ast.BlockStmt) + + params := ParseNode(node.ChildByFieldName("parameters"), source, ctx).(*ast.FieldList) + + // Special case for the main method, because in Java, this method has the + // command line args passed in as a parameter + if methodName.Name == "main" { + params = nil + body.List = append([]ast.Stmt{ &ast.AssignStmt{ Lhs: []ast.Expr{&ast.Ident{Name: "args"}}, Tok: token.DEFINE, @@ -360,11 +333,27 @@ func ParseDecl(node *sitter.Node, source []byte, ctx Ctx) ast.Decl { }, }, }, - }, method.Body.List...) + }, body.List...) } - return method + return &ast.FuncDecl{ + Doc: &ast.CommentGroup{List: comments}, + Name: &ast.Ident{Name: ctx.localScope.Name}, + Recv: receiver, + Type: &ast.FuncType{ + Params: params, + Results: &ast.FieldList{ + List: []*ast.Field{ + &ast.Field{Type: &ast.Ident{Name: ctx.localScope.Type}}, + }, + }, + }, + Body: body, + } case "static_initializer": + + ctx.localScope = &symbol.Definition{} + // A block of `static`, which is run before the main function return &ast.FuncDecl{ Name: &ast.Ident{Name: "init"}, diff --git a/dependency_tree.go b/dependency_tree.go deleted file mode 100644 index 0130bed..0000000 --- a/dependency_tree.go +++ /dev/null @@ -1,102 +0,0 @@ -package main - -import ( - "strings" - - sitter "github.com/smacker/go-tree-sitter" -) - -type ClassFile struct { - Name string - Package *PackageScope - Imports []*PackageScope -} - -// PackageScope is the package declaration for a single file -// it contains the name of the package (ex: "util"), as well as a pointer to -// another scope, which contains the rest of the scope (ex: "com.example") -// When the scope variable is `nil`, then you are at the root of the parent -type PackageScope struct { - Scope []string -} - -func (ps *PackageScope) String() string { - var 
total strings.Builder - for ind, item := range ps.Scope { - total.WriteString(item) - if ind < len(ps.Scope)-1 { - total.WriteRune('.') - } - } - return total.String() -} - -// ParseScope takes a identifier from an import node, and the source code, and -// parses it as a `PackageScope` type -func ParseScope(node *sitter.Node, source []byte) *PackageScope { - pack := &PackageScope{} - // A `scoped_identifier` contains two items, one for the scope, and the other - // for the name of the current package - - if node.Type() != "package_declaration" { - pack.Scope = []string{node.NamedChild(1).Content(source)} - } - - scope := node.NamedChild(0) - for scope.Type() == "scoped_identifier" { - pack.Scope = append(pack.Scope, scope.NamedChild(1).Content(source)) - scope = scope.NamedChild(0) - } - - pack.Scope = append(pack.Scope, scope.Content(source)) - - // Flip the order of the scope, because it is the wrong direction - for ind := 0; ind < len(pack.Scope)/2; ind++ { - rev := len(pack.Scope) - 1 - ind - pack.Scope[ind], pack.Scope[rev] = pack.Scope[rev], pack.Scope[ind] - } - - return pack -} - -// ExtractImports takes in a tree-sitter node and the source code, returning the -// parsed source file's imports, and other package-related data -func ExtractImports(node *sitter.Node, source []byte) *ClassFile { - class := &ClassFile{} - - // If this node has a package declaration, add it to the current class - if node.Type() == "package_declaration" { - class.Package = ParseScope(node, source) - } - - // If the node is an import node, return that as a single import - if node.Type() == "import_declaration" { - class.Imports = append(class.Imports, ParseScope(node.NamedChild(0), source)) - return class - } - - // Go through the children of the node to find everything else - for _, child := range Children(node) { - // Extract the node of the class being parsed - if node.Type() == "class_declaration" && child.Type() == "identifier" { - class.Name = child.Content(source) - } - - // 
Go through the children of the current node to find the imports - other := ExtractImports(child, source) - if len(other.Imports) > 0 { - class.Imports = append(class.Imports, other.Imports...) - } - // If the class name is unknown, and it has been found in one of the - // children, populate it - if class.Name == "" { - class.Name = other.Name - } - - if class.Package == nil { - class.Package = other.Package - } - } - - return class -} diff --git a/expression.go b/expression.go index d71f0bf..6a8b591 100644 --- a/expression.go +++ b/expression.go @@ -5,18 +5,14 @@ import ( "go/ast" "go/token" + "github.com/NickyBoy89/java2go/astutil" + "github.com/NickyBoy89/java2go/nodeutil" + "github.com/NickyBoy89/java2go/symbol" log "github.com/sirupsen/logrus" sitter "github.com/smacker/go-tree-sitter" ) func ParseExpr(node *sitter.Node, source []byte, ctx Ctx) ast.Expr { - if expr := TryParseExpr(node, source, ctx); expr != nil { - return expr - } - panic(fmt.Errorf("Unhandled expr type: %v", node.Type())) -} - -func TryParseExpr(node *sitter.Node, source []byte, ctx Ctx) ast.Expr { switch node.Type() { case "ERROR": log.WithFields(log.Fields{ @@ -57,10 +53,6 @@ func TryParseExpr(node *sitter.Node, source []byte, ctx Ctx) ast.Expr { ParseExpr(node.Child(2), source, ctx), }, } - case "scoped_type_identifier": - // This contains a reference to the type of a nested class - // Ex: LinkedList.Node - return &ast.StarExpr{X: &ast.Ident{Name: node.Content(source)}} case "super": return &ast.BadExpr{} case "lambda_expression": @@ -70,39 +62,45 @@ func TryParseExpr(node *sitter.Node, source []byte, ctx Ctx) ast.Expr { var lambdaBody *ast.BlockStmt - if expr := TryParseExpr(node.NamedChild(1), source, ctx); expr != nil { - // The body can be a single expression + var lambdaParameters *ast.FieldList + + bodyNode := node.ChildByFieldName("body") + + switch bodyNode.Type() { + case "block": + lambdaBody = ParseStmt(bodyNode, source, ctx).(*ast.BlockStmt) + default: + // Lambdas can be 
called inline without a block expression lambdaBody = &ast.BlockStmt{ List: []ast.Stmt{ &ast.ExprStmt{ - X: ParseExpr(node.NamedChild(1), source, ctx), + X: ParseExpr(bodyNode, source, ctx), }, }, } - } else { - lambdaBody = ParseStmt(node.NamedChild(1), source, ctx).(*ast.BlockStmt) } - switch node.NamedChild(0).Type() { + paramNode := node.ChildByFieldName("parameters") + + switch paramNode.Type() { case "inferred_parameters", "formal_parameters": - return &ast.FuncLit{ - Type: &ast.FuncType{ - Params: ParseNode(node.NamedChild(0), source, ctx).(*ast.FieldList), + lambdaParameters = ParseNode(paramNode, source, ctx).(*ast.FieldList) + default: + // If we can't identify the types of the parameters, then just set their + // types to any + lambdaParameters = &ast.FieldList{ + List: []*ast.Field{ + &ast.Field{ + Names: []*ast.Ident{ParseExpr(paramNode, source, ctx).(*ast.Ident)}, + Type: &ast.Ident{Name: "any"}, + }, }, - Body: lambdaBody, } } return &ast.FuncLit{ Type: &ast.FuncType{ - Params: &ast.FieldList{ - List: []*ast.Field{ - &ast.Field{ - Names: []*ast.Ident{ParseExpr(node.NamedChild(0), source, ctx).(*ast.Ident)}, - Type: &ast.Ident{Name: "interface{}"}, - }, - }, - }, + Params: lambdaParameters, }, Body: lambdaBody, } @@ -125,11 +123,18 @@ func TryParseExpr(node *sitter.Node, source []byte, ctx Ctx) ast.Expr { case "array_initializer": // A literal that initilzes an array, such as `{1, 2, 3}` items := []ast.Expr{} - for _, c := range Children(node) { + for _, c := range nodeutil.NamedChildrenOf(node) { items = append(items, ParseExpr(c, source, ctx)) } + + // If there wasn't a type for the array specified, then use the one that has been defined + if _, ok := ctx.lastType.(*ast.ArrayType); ctx.lastType != nil && ok { + return &ast.CompositeLit{ + Type: ctx.lastType.(*ast.ArrayType), + Elts: items, + } + } return &ast.CompositeLit{ - Type: ctx.lastType.(*ast.ArrayType), Elts: items, } case "method_invocation": @@ -151,54 +156,89 @@ func TryParseExpr(node 
*sitter.Node, source []byte, ctx Ctx) ast.Expr { case "object_creation_expression": // This is called when anything is created with a constructor - // Usually, this is called in this format: - // * The name of the type, this can either be an `identifier` or `generic_type` - // * An `argument_list` for the constructor's arguments + objectType := node.ChildByFieldName("type") + + // An object can also be created with this format: + // parentClass.new NestedClass() + if !node.NamedChild(0).Equal(objectType) { + } - // But, when creating a new inner class from an outer class, it can use this format: - // outerClass.new InnerClass() + // Get all the arguments, and look up their types + objectArguments := node.ChildByFieldName("arguments") + arguments := make([]ast.Expr, objectArguments.NamedChildCount()) + argumentTypes := make([]string, objectArguments.NamedChildCount()) + for ind, argument := range nodeutil.NamedChildrenOf(objectArguments) { + arguments[ind] = ParseExpr(argument, source, ctx) - // The name of the function will always be the last identifier - var functionNameInd int - for ind, c := range Children(node) { - if c.Type() == "type_identifier" { - functionNameInd = ind + // Look up each argument and find its type + if argument.Type() != "identifier" { + argumentTypes[ind] = symbol.TypeOfLiteral(argument, source) + } else { + if localDef := ctx.localScope.FindVariable(argument.Content(source)); localDef != nil { + argumentTypes[ind] = localDef.OriginalType + // Otherwise, a variable may exist as a global variable + } else if def := ctx.currentFile.FindField().ByOriginalName(argument.Content(source)); len(def) > 0 { + argumentTypes[ind] = def[0].OriginalType + } } } - var functionName string - parsed := ParseExpr(node.NamedChild(functionNameInd), source, ctx) - switch parsed.(type) { - case *ast.Ident: - functionName = parsed.(*ast.Ident).Name - case *ast.StarExpr: - functionName = parsed.(*ast.StarExpr).X.(*ast.Ident).Name + var constructor *symbol.Definition 
+ // Find the respective constructor, and call it + if objectType.Type() == "generic_type" { + constructor = ctx.currentClass.FindMethodByName(objectType.NamedChild(0).Content(source), argumentTypes) + } else { + constructor = ctx.currentClass.FindMethodByName(objectType.Content(source), argumentTypes) + } + + if constructor != nil { + return &ast.CallExpr{ + Fun: &ast.Ident{Name: constructor.Name}, + Args: arguments, + } } + // It is also possible that a constructor could be unresolved, so we handle + // this by calling a function named "Construct" followed by the name of the type return &ast.CallExpr{ - Fun: &ast.Ident{Name: "New" + functionName}, - Args: ParseNode(node.NamedChild(functionNameInd+1), source, ctx).([]ast.Expr), + Fun: &ast.Ident{Name: "Construct" + objectType.Content(source)}, + Args: arguments, } case "array_creation_expression": - // The type of the array - arrayType := ParseExpr(node.NamedChild(0), source, ctx) - // The dimensions of the array, which Golang only supports defining one at - // a time with the use of the builtin `make` - dimensions := []ast.Expr{&ast.ArrayType{Elt: arrayType}} - for _, c := range Children(node)[1:] { - if c.Type() == "dimensions_expr" { - dimensions = append(dimensions, ParseExpr(c, source, ctx)) + arguments := []ast.Expr{&ast.ArrayType{Elt: astutil.ParseType(node.ChildByFieldName("type"), source)}} + + for _, child := range nodeutil.NamedChildrenOf(node) { + if child.Type() == "dimensions_expr" { + arguments = append(arguments, ParseExpr(child, source, ctx)) + } + } + + var methodName string + switch len(arguments) - 1 { + case 0: + expr := ParseExpr(node.ChildByFieldName("value"), source, ctx).(*ast.CompositeLit) + expr.Type = &ast.ArrayType{ + Elt: astutil.ParseType(node.ChildByFieldName("type"), source), } + return expr + case 1: + methodName = "make" + case 2: + methodName = "MultiDimensionArray" + case 3: + methodName = "MultiDimensionArray3" + default: + panic("Unimplemented number of dimensions in array 
initializer") } return &ast.CallExpr{ - Fun: &ast.Ident{Name: "make"}, - Args: dimensions, + Fun: &ast.Ident{Name: methodName}, + Args: arguments, } case "instanceof_expression": return &ast.BadExpr{} case "dimensions_expr": - return &ast.Ident{Name: node.NamedChild(0).Content(source)} + return ParseExpr(node.NamedChild(0), source, ctx) case "binary_expression": if node.Child(1).Content(source) == ">>>" { return &ast.CallExpr{ @@ -225,7 +265,7 @@ func TryParseExpr(node *sitter.Node, source []byte, ctx Ctx) ast.Expr { // condition, and returns one of the two values, depending on the condition args := []ast.Expr{} - for _, c := range Children(node) { + for _, c := range nodeutil.NamedChildrenOf(node) { args = append(args, ParseExpr(c, source, ctx)) } return &ast.CallExpr{ @@ -233,14 +273,33 @@ func TryParseExpr(node *sitter.Node, source []byte, ctx Ctx) ast.Expr { Args: args, } case "cast_expression": + // TODO: This probably should be a cast function, instead of an assertion return &ast.TypeAssertExpr{ X: ParseExpr(node.NamedChild(1), source, ctx), - Type: ParseExpr(node.NamedChild(0), source, ctx), + Type: astutil.ParseType(node.NamedChild(0), source), } case "field_access": + // X.Sel + obj := node.ChildByFieldName("object") + + if obj.Type() == "this" { + def := ctx.currentClass.FindField().ByOriginalName(node.ChildByFieldName("field").Content(source)) + if len(def) == 0 { + // TODO: This field could not be found in the current class, because it exists in the superclass + // definition for the class + def = []*symbol.Definition{&symbol.Definition{ + Name: node.ChildByFieldName("field").Content(source), + }} + } + + return &ast.SelectorExpr{ + X: ParseExpr(node.ChildByFieldName("object"), source, ctx), + Sel: &ast.Ident{Name: def[0].Name}, + } + } return &ast.SelectorExpr{ - X: ParseExpr(node.NamedChild(0), source, ctx), - Sel: ParseExpr(node.NamedChild(1), source, ctx).(*ast.Ident), + X: ParseExpr(obj, source, ctx), + Sel: 
ParseExpr(node.ChildByFieldName("field"), source, ctx).(*ast.Ident), } case "array_access": return &ast.IndexExpr{ @@ -253,45 +312,22 @@ func TryParseExpr(node *sitter.Node, source []byte, ctx Ctx) ast.Expr { return &ast.Ident{Name: ShortName(ctx.className)} case "identifier": return &ast.Ident{Name: node.Content(source)} - case "integral_type": - switch node.Child(0).Type() { - case "int": - return &ast.Ident{Name: "int32"} - case "short": - return &ast.Ident{Name: "int16"} - case "long": - return &ast.Ident{Name: "int64"} - case "char": - return &ast.Ident{Name: "rune"} - case "byte": - return &ast.Ident{Name: node.Content(source)} - } - - panic(fmt.Errorf("Unknown integral type: %v", node.Child(0).Type())) - case "floating_point_type": // Can be either `float` or `double` - switch node.Child(0).Type() { - case "float": - return &ast.Ident{Name: "float32"} - case "double": - return &ast.Ident{Name: "float64"} - } - - panic(fmt.Errorf("Unknown float type: %v", node.Child(0).Type())) - case "void_type": - return &ast.Ident{} - case "boolean_type": - return &ast.Ident{Name: "bool"} - case "generic_type": - // A generic type is any type that is of the form GenericType - return &ast.Ident{Name: node.NamedChild(0).Content(source)} - case "array_type": - return &ast.ArrayType{Elt: ParseExpr(node.NamedChild(0), source, ctx)} case "type_identifier": // Any reference type switch node.Content(source) { // Special case for strings, because in Go, these are primitive types case "String": return &ast.Ident{Name: "string"} } + + if ctx.currentFile != nil { + // Look for the class locally first + if localClass := ctx.currentFile.FindClass(node.Content(source)); localClass != nil { + return &ast.StarExpr{ + X: &ast.Ident{Name: localClass.Name}, + } + } + } + return &ast.StarExpr{ X: &ast.Ident{Name: node.Content(source)}, } @@ -323,18 +359,5 @@ func TryParseExpr(node *sitter.Node, source []byte, ctx Ctx) ast.Expr { case "true", "false": return &ast.Ident{Name: 
node.Content(source)} } - return nil -} - -func ParseExprs(node *sitter.Node, source []byte, ctx Ctx) []ast.Expr { - if exprs := TryParseExprs(node, source, ctx); exprs != nil { - return exprs - } - panic(fmt.Errorf("Unhandled type for exprs: %v", node.Type())) -} - -func TryParseExprs(node *sitter.Node, source []byte, ctx Ctx) []ast.Expr { - switch node.Type() { - } - return nil + panic("Unhandled expression: " + node.Type()) } diff --git a/go.mod b/go.mod index b558b3c..560a42b 100644 --- a/go.mod +++ b/go.mod @@ -4,11 +4,11 @@ go 1.18 require ( github.com/sirupsen/logrus v1.8.1 - github.com/smacker/go-tree-sitter v0.0.0-20220323032108-9170be4a682c + github.com/smacker/go-tree-sitter v0.0.0-20220421092837-ec55f7cfeaf4 golang.org/x/exp v0.0.0-20220312040426-20fd27f61765 ) require ( - golang.org/x/sys v0.0.0-20220403205710-6acee93ad0eb // indirect + golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/go.sum b/go.sum index fa7bbf4..f19990f 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/smacker/go-tree-sitter v0.0.0-20220323032108-9170be4a682c h1:DyKhBbcwOEbQ5zJXqsIU631DZgdtnUJtUYOA1PxF5QM= github.com/smacker/go-tree-sitter v0.0.0-20220323032108-9170be4a682c/go.mod h1:EiUuVMUfLQj8Sul+S8aKWJwQy7FRYnJCO2EWzf8F5hk= +github.com/smacker/go-tree-sitter v0.0.0-20220421092837-ec55f7cfeaf4 h1:UFOHRX5nrxNCVORhicjy31nzSVt9rEjf/YRcx2Dc3MM= +github.com/smacker/go-tree-sitter v0.0.0-20220421092837-ec55f7cfeaf4/go.mod h1:EiUuVMUfLQj8Sul+S8aKWJwQy7FRYnJCO2EWzf8F5hk= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= @@ -16,6 +18,8 @@ 
golang.org/x/exp v0.0.0-20220312040426-20fd27f61765/go.mod h1:lgLbSvA5ygNOMpwM/9 golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220403205710-6acee93ad0eb h1:PVGECzEo9Y3uOidtkHGdd347NjLtITfJFO9BxFpmRoo= golang.org/x/sys v0.0.0-20220403205710-6acee93ad0eb/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a h1:dGzPydgVsqGcTRVwiLJ1jVbufYwmzD3LfVPLKsKg+0k= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= diff --git a/java2go.go b/java2go.go index 58dbea1..611b89c 100644 --- a/java2go.go +++ b/java2go.go @@ -9,14 +9,14 @@ import ( "io" "io/fs" "os" + "path" "path/filepath" "runtime" "runtime/pprof" "strings" "sync" - stdpath "path" - + "github.com/NickyBoy89/java2go/symbol" log "github.com/sirupsen/logrus" sitter "github.com/smacker/go-tree-sitter" "github.com/smacker/go-tree-sitter/java" @@ -24,27 +24,37 @@ import ( var ( // Stores a global list of Java annotations to exclude from the generated code - excludedAnnotations = make(map[string]struct{}) + excludedAnnotations = make(map[string]bool) ) -func main() { - parser := sitter.NewParser() - parser.SetLanguage(java.GetLanguage()) - - writeFlag := flag.Bool("w", false, "Whether to write the files to disk instead of stdout") - quiet := flag.Bool("q", false, "Don't write to stdout on successful parse") - astFlag := flag.Bool("ast", false, "Print out go's pretty-printed ast, instead of source code") - syncFlag := flag.Bool("sync", false, "Parse the files sequentially, instead of multi-threaded") - outDirFlag := flag.String("outDir", ".", "Specify a directory for the 
generated files") - - var cpuprofile = flag.String("cpuprofile", "", "write cpu profile to `file`") +type SourceFile struct { + Name string + Source []byte + Ast *sitter.Node + Symbols *symbol.FileScope +} - excludeAnnotationsFlag := flag.String("exclude-annotations", "", "A comma-separated list of annotations to exclude from the final code generation") +func main() { + var writeFiles, quiet, displayAST, symbolAware, parseFilesSynchronously bool + var outputDirectory, ignoredAnnotations, cpuProfile string + + flag.BoolVar(&writeFiles, "w", false, "Whether to write the files to disk instead of stdout") + flag.BoolVar(&quiet, "q", false, "Don't write to stdout on successful parse") + flag.BoolVar(&displayAST, "ast", false, "Print out go's pretty-printed ast, instead of source code") + flag.BoolVar(&parseFilesSynchronously, "sync", false, "Parse the files one by one, instead of in parallel") + flag.BoolVar(&symbolAware, "symbols", true, `Whether the program is aware of the symbols of the parsed code +Results in better code generation, but can be disabled for a more direct translation +or to fix crashes with the symbol handling`, + ) + flag.StringVar(&outputDirectory, "outDir", ".", "Specify a directory for the generated files") + flag.StringVar(&ignoredAnnotations, "exclude-annotations", "", "A comma-separated list of annotations to exclude from the final code generation") + + flag.StringVar(&cpuProfile, "cpuprofile", "", "write cpu profile to `file`") flag.Parse() - if *cpuprofile != "" { - f, err := os.Create(*cpuprofile) + if cpuProfile != "" { + f, err := os.Create(cpuProfile) if err != nil { log.Fatal("could not create CPU profile: ", err) } @@ -55,25 +65,38 @@ func main() { defer pprof.StopCPUProfile() } - for _, annotation := range strings.Split(*excludeAnnotationsFlag, ",") { - excludedAnnotations[annotation] = struct{}{} + for _, annotation := range strings.Split(ignoredAnnotations, ",") { + excludedAnnotations[annotation] = true } - // Sem determines the 
number of files parsed in parallel - sem := make(chan struct{}, runtime.NumCPU()) - // All the files to parse - fileNames := []string{} + files := []SourceFile{} - for _, file := range flag.Args() { - err := filepath.WalkDir(file, fs.WalkDirFunc(func(path string, d fs.DirEntry, err error) error { - // Only include java files - if filepath.Ext(path) == ".java" && !d.IsDir() { - fileNames = append(fileNames, path) - } + log.Info("Collecting files...") - return nil - })) + // Collect all the files and read them into memory + for _, file := range flag.Args() { + err := filepath.WalkDir(file, fs.WalkDirFunc( + func(path string, d fs.DirEntry, err error) error { + // Only include java files + if filepath.Ext(path) == ".java" && !d.IsDir() { + sourceCode, err := os.ReadFile(path) + if err != nil { + log.WithFields(log.Fields{ + "file": path, + "error": err, + }).Panic("Error reading source file") + } + + files = append(files, SourceFile{ + Name: path, + Source: sourceCode, + }) + } + + return nil + }, + )) if err != nil { log.WithFields(log.Fields{ @@ -83,50 +106,102 @@ func main() { } } - if len(fileNames) == 0 { + if len(files) == 0 { log.Warn("No files specified to convert") return } + // Parse the ASTs of all the files + + log.Info("Parsing ASTs...") + + sem := make(chan struct{}, runtime.NumCPU()) + var wg sync.WaitGroup - wg.Add(len(fileNames)) + wg.Add(len(files)) - // Start looking through the files - for _, path := range fileNames { - sourceCode, err := os.ReadFile(path) - if err != nil { - log.WithFields(log.Fields{ - "file": path, - "error": err, - }).Panic("Error reading source file") + for index := range files { + sem <- struct{}{} + + go func(index int) { + parser := sitter.NewParser() + parser.SetLanguage(java.GetLanguage()) + tree, err := parser.ParseCtx(context.Background(), nil, files[index].Source) + if err != nil { + log.WithFields(log.Fields{ + "error": err, + }).Panic("Error parsing tree-sitter AST") + } + + parser.Close() + + files[index].Ast = 
tree.RootNode() + + <-sem + wg.Done() + }(index) + } + + // We might still have some parsing jobs, so wait on them + wg.Wait() + + // Generate the symbol tables for the files + + if symbolAware { + log.Info("Generating symbol tables...") + + for index, file := range files { + if file.Ast.HasError() { + log.WithFields(log.Fields{ + "fileName": file.Name, + }).Warn("AST parse error in file, skipping file") + continue + } + + symbols := symbol.ParseSymbols(file.Ast, file.Source) + + files[index].Symbols = symbols + + if _, exist := symbol.GlobalScope.Packages[symbols.Package]; !exist { + symbol.GlobalScope.Packages[symbols.Package] = &symbol.PackageScope{Files: make(map[string]*symbol.FileScope)} + } + + symbol.GlobalScope.Packages[symbols.Package].AddSymbolsFromFile(symbols) } - tree, err := parser.ParseCtx(context.Background(), nil, sourceCode) - if err != nil { - log.WithFields(log.Fields{ - "error": err, - }).Panic("Error parsing tree-sitter AST") + // Go back through the symbol tables and fill in anything that could not be resolved + + log.Info("Resolving symbols...") + + for _, file := range files { + ResolveFile(file) } + } - n := tree.RootNode() + // Transpile the files - log.Infof("Converting file \"%s\"", path) + log.Info("Converting files...") + + for _, file := range files { + log.Infof("Converting file \"%s\"", file.Name) // Write to stdout by default var output io.Writer = os.Stdout - outputFile := path[:len(path)-len(filepath.Ext(path))] + ".go" - outputPath := *outDirFlag + "/" + outputFile + // Write to a `.go` file in the same directory + outputFile := file.Name[:len(file.Name)-len(filepath.Ext(file.Name))] + ".go" + outputPath := outputDirectory + "/" + outputFile - if *writeFlag { - if err := os.MkdirAll(stdpath.Dir(outputPath), 0755); err != nil { + if writeFiles { + err := os.MkdirAll(path.Dir(outputPath), 0755) + if err != nil { log.WithFields(log.Fields{ "error": err, "path": outputPath, }).Panic("Error creating output directory") } - // Write 
the output to another file + // Write the output to a file output, err = os.Create(outputPath) if err != nil { log.WithFields(log.Fields{ @@ -134,41 +209,34 @@ func main() { "file": outputPath, }).Panic("Error creating output file") } - defer output.(*os.File).Close() - } else if *quiet { + } else if quiet { + // Otherwise, throw away the output output = io.Discard } - // Acquire a semaphore - sem <- struct{}{} - - parseFunc := func() { - // Release the semaphore when done - defer func() { <-sem }() - - defer wg.Done() + // The converted AST, in Go's AST representation + var initialContext Ctx + if symbolAware { + initialContext.currentFile = file.Symbols + initialContext.currentClass = file.Symbols.BaseClass + } - parsedAst := ParseNode(n, sourceCode, Ctx{}).(ast.Node) + parsed := ParseNode(file.Ast, file.Source, initialContext).(ast.Node) - // Print the generated AST - if *astFlag { - ast.Print(token.NewFileSet(), parsedAst) - } + // Print the generated AST + if displayAST { + ast.Print(token.NewFileSet(), parsed) + } - if err := printer.Fprint(output, token.NewFileSet(), parsedAst); err != nil { - log.WithFields(log.Fields{ - "error": err, - }).Panic("Error printing generated code") - } + // Output the parsed AST, into the source specified earlier + if err := printer.Fprint(output, token.NewFileSet(), parsed); err != nil { + log.WithFields(log.Fields{ + "error": err, + }).Panic("Error printing generated code") } - // If we don't want this to run in parallel - if *syncFlag { - parseFunc() - } else { - go parseFunc() + if writeFiles { + output.(*os.File).Close() } } - - wg.Wait() } diff --git a/nodeutil/assertions.go b/nodeutil/assertions.go new file mode 100644 index 0000000..3871ab1 --- /dev/null +++ b/nodeutil/assertions.go @@ -0,0 +1,13 @@ +package nodeutil + +import ( + "fmt" + + sitter "github.com/smacker/go-tree-sitter" +) + +func AssertTypeIs(node *sitter.Node, expectedType string) { + if node.Type() != expectedType { + panic(fmt.Sprintf("assertion failed: 
Type of node differs from expected: %s, got: %s", expectedType, node.Type())) + } +} diff --git a/nodeutil/node_helpers.go b/nodeutil/node_helpers.go new file mode 100644 index 0000000..1ec116a --- /dev/null +++ b/nodeutil/node_helpers.go @@ -0,0 +1,23 @@ +package nodeutil + +import sitter "github.com/smacker/go-tree-sitter" + +// NamedChildrenOf gets all named children of a given node +func NamedChildrenOf(node *sitter.Node) []*sitter.Node { + count := int(node.NamedChildCount()) + children := make([]*sitter.Node, count) + for i := 0; i < count; i++ { + children[i] = node.NamedChild(i) + } + return children +} + +// UnnamedChildrenOf gets all the named + unnamed children of a given node +func UnnamedChildrenOf(node *sitter.Node) []*sitter.Node { + count := int(node.ChildCount()) + children := make([]*sitter.Node, count) + for i := 0; i < count; i++ { + children[i] = node.Child(i) + } + return children +} diff --git a/resolve.go b/resolve.go new file mode 100644 index 0000000..89d56d6 --- /dev/null +++ b/resolve.go @@ -0,0 +1,77 @@ +package main + +import ( + "strconv" + + "github.com/NickyBoy89/java2go/symbol" +) + +func ResolveFile(file SourceFile) { + ResolveClass(file.Symbols.BaseClass, file) + for _, subclass := range file.Symbols.BaseClass.Subclasses { + ResolveClass(subclass, file) + } +} + +func ResolveClass(class *symbol.ClassScope, file SourceFile) { + // Resolve all the fields in that respective class + for _, field := range class.Fields { + + // Since a private global variable is able to be accessed in the package, it must be renamed + // to avoid conflicts with other global variables + + packageScope := symbol.GlobalScope.FindPackage(file.Symbols.Package) + + symbol.ResolveDefinition(field, file.Symbols) + + // Rename the field if its name conflits with any keyword + for i := 0; symbol.IsReserved(field.Name) || + len(packageScope.ExcludeFile(class.Class.Name).FindStaticField().ByName(field.Name)) > 0; i++ { + field.Rename(field.Name + strconv.Itoa(i)) 
+ } + } + + // Resolve all the methods + for _, method := range class.Methods { + // Resolve the return type, as well as the body of the method + symbol.ResolveChildren(method, file.Symbols) + + // Comparison compares the method against the found method + // This tests for a method of the same name, but with different + // aspects of it, so that it can be identified as a duplicate + comparison := func(d *symbol.Definition) bool { + // The names must match, but everything else must be different + if method.Name != d.Name { + return false + } + + // Size of parameters do not match + if len(method.Parameters) != len(d.Parameters) { + return true + } + + // Go through the types and check to see if they differ + for index, param := range method.Parameters { + if param.OriginalType != d.Parameters[index].OriginalType { + return true + } + } + + // Both methods are equal, skip this method since it is likely + // the same method that we are trying to find duplicates of + return false + } + + for i := 0; symbol.IsReserved(method.Name) || len(class.FindMethod().By(comparison)) > 0; i++ { + method.Rename(method.Name + strconv.Itoa(i)) + } + // Resolve all the paramters of the method + for _, param := range method.Parameters { + symbol.ResolveDefinition(param, file.Symbols) + + for i := 0; symbol.IsReserved(param.Name); i++ { + param.Rename(param.Name + strconv.Itoa(i)) + } + } + } +} diff --git a/statement.go b/statement.go index 623946f..7a8c2a2 100644 --- a/statement.go +++ b/statement.go @@ -5,6 +5,8 @@ import ( "go/ast" "go/token" + "github.com/NickyBoy89/java2go/astutil" + "github.com/NickyBoy89/java2go/nodeutil" log "github.com/sirupsen/logrus" sitter "github.com/smacker/go-tree-sitter" ) @@ -27,39 +29,38 @@ func TryParseStmt(node *sitter.Node, source []byte, ctx Ctx) ast.Stmt { case "comment": return &ast.BadStmt{} case "local_variable_declaration": - var varTypeIndex int + variableType := astutil.ParseType(node.ChildByFieldName("type"), source) + variableDeclarator := 
node.ChildByFieldName("declarator") - // The first child can either be modifiers e.g `final int var = 1`, or - // just the variable's type - if node.NamedChild(0).Type() == "modifiers" { - varTypeIndex = 1 - } - - // The variable declarator does not have a value (ex: int value;) - if node.NamedChild(varTypeIndex+1).NamedChildCount() == 1 { + // If a variable is being declared, but not set to a value + // Ex: `int value;` + if variableDeclarator.NamedChildCount() == 1 { return &ast.DeclStmt{ Decl: &ast.GenDecl{ Tok: token.VAR, Specs: []ast.Spec{ &ast.ValueSpec{ - Names: []*ast.Ident{ParseExpr(node.NamedChild(varTypeIndex+1).NamedChild(0), source, ctx).(*ast.Ident)}, - Type: ParseExpr(node.NamedChild(varTypeIndex), source, ctx), + Names: []*ast.Ident{ParseExpr(variableDeclarator.ChildByFieldName("name"), source, ctx).(*ast.Ident)}, + Type: variableType, }, }, }, } } - ctx.lastType = ParseExpr(node.NamedChild(varTypeIndex), source, ctx) + ctx.lastType = variableType - declaration := ParseStmt(node.NamedChild(varTypeIndex+1), source, ctx).(*ast.AssignStmt) + declaration := ParseStmt(variableDeclarator, source, ctx).(*ast.AssignStmt) + // Now, if a variable is assigned to `null`, we can't infer its type, so + // don't throw out the type information associated with it var containsNull bool // Go through the values and see if there is a `null_literal` - for _, child := range Children(node.NamedChild(varTypeIndex + 1)) { + for _, child := range nodeutil.NamedChildrenOf(variableDeclarator) { if child.Type() == "null_literal" { containsNull = true + break } } @@ -77,7 +78,7 @@ func TryParseStmt(node *sitter.Node, source []byte, ctx Ctx) ast.Stmt { Specs: []ast.Spec{ &ast.ValueSpec{ Names: names, - Type: ParseExpr(node.NamedChild(varTypeIndex), source, ctx), + Type: variableType, Values: declaration.Rhs, }, }, @@ -146,7 +147,7 @@ func TryParseStmt(node *sitter.Node, source []byte, ctx Ctx) ast.Stmt { return &ast.ExprStmt{X: ParseExpr(node, source, ctx)} case 
"constructor_body", "block": body := &ast.BlockStmt{} - for _, line := range Children(node) { + for _, line := range nodeutil.NamedChildrenOf(node) { if line.Type() == "comment" { continue } @@ -272,7 +273,7 @@ func TryParseStmt(node *sitter.Node, source []byte, ctx Ctx) ast.Stmt { case "switch_block": switchBlock := &ast.BlockStmt{} var currentCase *ast.CaseClause - for _, c := range Children(node) { + for _, c := range nodeutil.NamedChildrenOf(node) { switch c.Type() { case "switch_label": // When a new switch label comes, append it to the switch block diff --git a/stdjava/common.go b/stdjava/common.go index e788e2d..3cddddc 100644 --- a/stdjava/common.go +++ b/stdjava/common.go @@ -33,7 +33,25 @@ func HashCode(s string) int { var total int n := len(s) for ind, char := range s { - total += int(char) * int(math.Pow(float64(31), float64(n - (ind+1)))) + total += int(char) * int(math.Pow(float64(31), float64(n-(ind+1)))) } return total } + +// MultiDimensionArray constructs an array with two dimensions +func MultiDimensionArray[T any](val []T, dims ...int) [][]T { + arr := make([][]T, dims[0]) + for ind := range arr { + arr[ind] = make([]T, dims[1]) + } + return arr +} + +// MultiDimensionArray3 constructs an array with three dimensions +func MultiDimensionArray3[T any](val [][]T, dims ...int) [][][]T { + arr := make([][][]T, dims[0]) + for ind := range arr { + arr[ind] = MultiDimensionArray([]T{}, dims[1:]...) 
+ } + return arr +} diff --git a/stdjava/multidim_test.go b/stdjava/multidim_test.go new file mode 100644 index 0000000..e744950 --- /dev/null +++ b/stdjava/multidim_test.go @@ -0,0 +1,26 @@ +package stdjava + +import "testing" + +func TestSimpleGrid(t *testing.T) { + arr := MultiDimensionArray([]string{}, 3, 5) + if len(arr) != 3 { + t.Errorf("Got %d rows, expected %d", len(arr), 3) + } + if len(arr[0]) != 5 { + t.Errorf("Got %d cols, expected %d", len(arr[0]), 5) + } +} + +func TestAdvancedGrid(t *testing.T) { + arr := MultiDimensionArray3([][]int{}, 1, 2, 3) + if len(arr) != 1 { + t.Errorf("Got %d rows, expected %d", len(arr), 1) + } + if len(arr[0]) != 2 { + t.Errorf("Got %d cols, expected %d", len(arr[0]), 2) + } + if len(arr[0][0]) != 3 { + t.Errorf("Got %d third, expected %d", len(arr), 3) + } +} diff --git a/stdjava/optional.go b/stdjava/optional.go index 452bde4..aec33b2 100644 --- a/stdjava/optional.go +++ b/stdjava/optional.go @@ -6,6 +6,6 @@ type Optional[T any] struct { } // Some returns true if a value is present -func (o Optional) Some() bool { +func (o Optional[T]) Some() bool { return o.value != nil } diff --git a/symbol/class_scope.go b/symbol/class_scope.go new file mode 100644 index 0000000..d3a11b1 --- /dev/null +++ b/symbol/class_scope.go @@ -0,0 +1,145 @@ +package symbol + +// ClassScope represents a single defined class, and the declarations in it +type ClassScope struct { + // The definition for the class defined within the class + Class *Definition + // Every class that is nested within the base class + Subclasses []*ClassScope + // Any normal and static fields associated with the class + Fields []*Definition + // Methods and constructors + Methods []*Definition +} + +// FindMethod searches through the immediate class's methods find a specific method +func (cs *ClassScope) FindMethod() Finder { + cm := classMethodFinder(*cs) + return &cm +} + +// FindField searches through the immediate class's fields to find a specific field +func (cs 
*ClassScope) FindField() Finder { + cm := classFieldFinder(*cs) + return &cm +} + +type classMethodFinder ClassScope + +func (cm *classMethodFinder) By(criteria func(d *Definition) bool) []*Definition { + results := []*Definition{} + for _, method := range cm.Methods { + if criteria(method) { + results = append(results, method) + } + } + return results +} + +func (cm *classMethodFinder) ByName(name string) []*Definition { + return cm.By(func(d *Definition) bool { + return d.Name == name + }) +} + +func (cm *classMethodFinder) ByOriginalName(originalName string) []*Definition { + return cm.By(func(d *Definition) bool { + return d.OriginalName == originalName + }) +} + +type classFieldFinder ClassScope + +func (cm *classFieldFinder) By(criteria func(d *Definition) bool) []*Definition { + results := []*Definition{} + for _, method := range cm.Fields { + if criteria(method) { + results = append(results, method) + } + } + return results +} + +func (cm *classFieldFinder) ByName(name string) []*Definition { + return cm.By(func(d *Definition) bool { + return d.Name == name + }) +} + +func (cm *classFieldFinder) ByOriginalName(originalName string) []*Definition { + return cm.By(func(d *Definition) bool { + return d.OriginalName == originalName + }) +} + +// FindMethodByDisplayName searches for a given method by its display name +// If some ignored parameter types are specified as non-nil, it will skip over +// any function that matches these ignored parameter types exactly +func (cs *ClassScope) FindMethodByName(name string, ignoredParameterTypes []string) *Definition { + return cs.findMethodWithComparison(func(method *Definition) bool { return method.OriginalName == name }, ignoredParameterTypes) +} + +// FindMethodByDisplayName searches for a given method by its display name +// If some ignored parameter types are specified as non-nil, it will skip over +// any function that matches these ignored parameter types exactly +func (cs *ClassScope) FindMethodByDisplayName(name 
string, ignoredParameterTypes []string) *Definition { + return cs.findMethodWithComparison(func(method *Definition) bool { return method.Name == name }, ignoredParameterTypes) +} + +func (cs *ClassScope) findMethodWithComparison(comparison func(method *Definition) bool, ignoredParameterTypes []string) *Definition { + for _, method := range cs.Methods { + if comparison(method) { + // If no parameters were specified to ignore, then return the first match + if ignoredParameterTypes == nil { + return method + } else if len(method.Parameters) != len(ignoredParameterTypes) { // Size of parameters were not equal, instantly not equal + return method + } + + // Check the remaining paramters one-by-one + for index, parameter := range method.Parameters { + if parameter.OriginalType != ignoredParameterTypes[index] { + return method + } + } + } + } + + // Not found + return nil +} + +// FindClass searches through a class file and returns the definition for the +// found class, or nil if none was found +func (cs *ClassScope) FindClass(name string) *Definition { + if cs.Class.OriginalName == name { + return cs.Class + } + for _, subclass := range cs.Subclasses { + class := subclass.FindClass(name) + if class != nil { + return class + } + } + return nil +} + +// FindFieldByName searches for a field by its original name, and returns its definition +// or nil if none was found +func (cs *ClassScope) FindFieldByName(name string) *Definition { + for _, field := range cs.Fields { + if field.OriginalName == name { + return field + } + } + return nil +} + +func (cs *ClassScope) FindFieldByDisplayName(name string) *Definition { + for _, field := range cs.Fields { + if field.Name == name { + return field + } + } + return nil +} diff --git a/symbol/definition.go b/symbol/definition.go new file mode 100644 index 0000000..b9722e2 --- /dev/null +++ b/symbol/definition.go @@ -0,0 +1,66 @@ +package symbol + +// Definition represents the name and type of a single symbol +type Definition struct { 
+ // The original Java name + OriginalName string + // The display name of the definition, may be different from the original name + Name string + // Original Java type of the object + OriginalType string + // Display type of the object + Type string + + // If the definition is a constructor + // This is used so that the definition handles its special naming and + // type rules correctly + Constructor bool + // If the object is a function, it has parameters + Parameters []*Definition + // Children of the declaration, if the declaration is a scope + Children []*Definition +} + +// Rename changes the display name of a definition +func (d *Definition) Rename(name string) { + d.Name = name +} + +// ParameterByName returns a parameter's definition, given its original name +func (d *Definition) ParameterByName(name string) *Definition { + for _, param := range d.Parameters { + if param.OriginalName == name { + return param + } + } + return nil +} + +// OriginalParameterTypes returns a list of the original types for all the parameters +func (d *Definition) OriginalParameterTypes() []string { + names := make([]string, len(d.Parameters)) + for ind, param := range d.Parameters { + names[ind] = param.OriginalType + } + return names +} + +// FindVariable searches a definition's immediate children and parameters +// to try and find a given variable by its original name +func (d *Definition) FindVariable(name string) *Definition { + for _, param := range d.Parameters { + if param.OriginalName == name { + return param + } + } + for _, child := range d.Children { + if child.OriginalName == name { + return child + } + } + return nil +} + +func (d Definition) IsEmpty() bool { + return d.OriginalName == "" && len(d.Children) == 0 +} diff --git a/symbol/file_scope.go b/symbol/file_scope.go new file mode 100644 index 0000000..a48a0be --- /dev/null +++ b/symbol/file_scope.go @@ -0,0 +1,60 @@ +package symbol + +// FileScope represents the scope in a single source file, that can contain 
one +// or more source classes +type FileScope struct { + // The global package that the file is located in + Package string + // Every external package that is imported into the file + // Formatted as map[ImportedType: full.package.path] + Imports map[string]string + // The base class that is in the file + BaseClass *ClassScope +} + +// FindClass searches through a file to find if a given class has been defined +// at its root class, or within any of the subclasses +func (fs *FileScope) FindClass(name string) *Definition { + if def := fs.BaseClass.FindClass(name); def != nil { + return def + } + for _, subclass := range fs.BaseClass.Subclasses { + if def := subclass.FindClass(name); def != nil { + return def + } + } + return nil +} + +// FindField searches through all of the classes in a file and determines if a +// field exists +func (cs *FileScope) FindField() Finder { + cm := fileFieldFinder(*cs) + return &cm +} + +type fileFieldFinder FileScope + +func findFieldsInClass(class *ClassScope, criteria func(d *Definition) bool) []*Definition { + defs := class.FindField().By(criteria) + for _, subclass := range class.Subclasses { + defs = append(defs, findFieldsInClass(subclass, criteria)...) 
+ } + return defs +} + +func (ff *fileFieldFinder) By(criteria func(d *Definition) bool) []*Definition { + return findFieldsInClass(ff.BaseClass, criteria) +} + +func (ff *fileFieldFinder) ByName(name string) []*Definition { + return ff.By(func(d *Definition) bool { + return d.Name == name + }) +} + +func (ff *fileFieldFinder) ByOriginalName(originalName string) []*Definition { + return ff.By(func(d *Definition) bool { + return d.OriginalName == originalName + }) +} diff --git a/symbol/find.go b/symbol/find.go new file mode 100644 index 0000000..811f57f --- /dev/null +++ b/symbol/find.go @@ -0,0 +1,9 @@ +package symbol + +// Finder represents an object that can search through its contents for a given +// list of definitions that match a certian criteria +type Finder interface { + By(criteria func(d *Definition) bool) []*Definition + ByName(name string) []*Definition + ByOriginalName(originalName string) []*Definition +} diff --git a/symbol/globals.go b/symbol/globals.go new file mode 100644 index 0000000..8eb9e05 --- /dev/null +++ b/symbol/globals.go @@ -0,0 +1,25 @@ +package symbol + +var ( + // The global symbol table + GlobalScope = &GlobalSymbols{Packages: make(map[string]*PackageScope)} +) + +// A GlobalSymbols represents a global view of all the packages in the parsed source +type GlobalSymbols struct { + // Every package's path associatedd with its definition + Packages map[string]*PackageScope +} + +func (gs *GlobalSymbols) String() string { + result := "" + for packageName := range gs.Packages { + result += packageName + "\n" + } + return result +} + +// FindPackage looks up a package's path in the global scope, and returns it +func (gs *GlobalSymbols) FindPackage(name string) *PackageScope { + return gs.Packages[name] +} diff --git a/symbol/package_scope.go b/symbol/package_scope.go new file mode 100644 index 0000000..f4e77c5 --- /dev/null +++ b/symbol/package_scope.go @@ -0,0 +1,69 @@ +package symbol + +// PackageScope represents a single package, which 
can contain one or more files +type PackageScope struct { + // Maps the file's name to its definitions + Files map[string]*FileScope +} + +func (ps *PackageScope) ExcludeFile(excludedFileName string) *PackageScope { + newScope := &PackageScope{Files: make(map[string]*FileScope)} + for fileName, fileScope := range ps.Files { + if fileName != excludedFileName { + newScope.Files[fileName] = fileScope + } + } + return newScope +} + +func (ps *PackageScope) FindStaticField() Finder { + pf := PackageFieldFinder(*ps) + return &pf +} + +type PackageFieldFinder PackageScope + +func (pf *PackageFieldFinder) By(criteria func(d *Definition) bool) []*Definition { + results := []*Definition{} + for _, file := range pf.Files { + for _, field := range file.BaseClass.Fields { + if criteria(field) { + results = append(results, field) + } + } + } + return results +} + +func (ps *PackageFieldFinder) ByName(name string) []*Definition { + return ps.By(func(d *Definition) bool { + return d.Name == name + }) +} + +func (ps *PackageFieldFinder) ByOriginalName(originalName string) []*Definition { + return ps.By(func(d *Definition) bool { + return d.Name == originalName + }) +} + +func (ps *PackageScope) AddSymbolsFromFile(symbols *FileScope) { + ps.Files[symbols.BaseClass.Class.Name] = symbols +} + +// FindClass searches for a class in the given package and returns a scope for it +// the class may be the subclass of another class +func (ps *PackageScope) FindClass(name string) *ClassScope { + for _, fileScope := range ps.Files { + if fileScope.BaseClass.Class.OriginalName == name { + return fileScope.BaseClass + } + for _, subclass := range fileScope.BaseClass.Subclasses { + class := subclass.FindClass(name) + if class != nil { + return fileScope.BaseClass + } + } + } + return nil +} diff --git a/symbol/parsing.go b/symbol/parsing.go new file mode 100644 index 0000000..e02298d --- /dev/null +++ b/symbol/parsing.go @@ -0,0 +1,200 @@ +package symbol + +import ( + 
"github.com/NickyBoy89/java2go/astutil" + "github.com/NickyBoy89/java2go/nodeutil" + sitter "github.com/smacker/go-tree-sitter" +) + +// ParseSymbols generates a symbol table for a single class file. +func ParseSymbols(root *sitter.Node, source []byte) *FileScope { + var filePackage string + + var baseClass *sitter.Node + + imports := make(map[string]string) + for _, node := range nodeutil.NamedChildrenOf(root) { + switch node.Type() { + case "package_declaration": + filePackage = node.NamedChild(0).Content(source) + case "import_declaration": + importedItem := node.NamedChild(0).ChildByFieldName("name").Content(source) + importPath := node.NamedChild(0).ChildByFieldName("scope").Content(source) + + imports[importedItem] = importPath + case "class_declaration", "interface_declaration", "enum_declaration", "annotation_type_declaration": + baseClass = node + } + } + + return &FileScope{ + Imports: imports, + Package: filePackage, + BaseClass: parseClassScope(baseClass, source), + } +} + +func parseClassScope(root *sitter.Node, source []byte) *ClassScope { + var public bool + // Rename the type based on the public/static rules + if root.NamedChild(0).Type() == "modifiers" { + for _, node := range nodeutil.UnnamedChildrenOf(root.NamedChild(0)) { + if node.Type() == "public" { + public = true + } + } + } + + nodeutil.AssertTypeIs(root.ChildByFieldName("name"), "identifier") + + // Parse the main class in the file + + className := root.ChildByFieldName("name").Content(source) + scope := &ClassScope{ + Class: &Definition{ + OriginalName: className, + Name: HandleExportStatus(public, className), + }, + } + + // Parse the body of the class + + for _, node := range nodeutil.NamedChildrenOf(root.ChildByFieldName("body")) { + + switch node.Type() { + case "field_declaration": + var public bool + // Rename the type based on the public/static rules + if node.NamedChild(0).Type() == "modifiers" { + for _, modifier := range nodeutil.UnnamedChildrenOf(node.NamedChild(0)) { + if 
modifier.Type() == "public" { + public = true + } + } + } + + fieldNameNode := node.ChildByFieldName("declarator").ChildByFieldName("name") + + nodeutil.AssertTypeIs(fieldNameNode, "identifier") + + // TODO: Scoped type identifiers are in a format such as RemotePackage.ClassName + // To handle this, we remove the RemotePackage part, and depend on the later + // type resolution to figure things out + + // The node that the field's type comes from + typeNode := node.ChildByFieldName("type") + + // If the field is being assigned to a value + if typeNode.Type() == "scoped_type_identifier" { + typeNode = typeNode.NamedChild(int(typeNode.NamedChildCount()) - 1) + } + + // The converted name and type of the field + fieldName := fieldNameNode.Content(source) + fieldType := nodeToStr(astutil.ParseType(typeNode, source)) + + scope.Fields = append(scope.Fields, &Definition{ + Name: HandleExportStatus(public, fieldName), + OriginalName: fieldName, + Type: fieldType, + OriginalType: typeNode.Content(source), + }) + case "method_declaration", "constructor_declaration": + var public bool + // Rename the type based on the public/static rules + if node.NamedChild(0).Type() == "modifiers" { + for _, modifier := range nodeutil.UnnamedChildrenOf(node.NamedChild(0)) { + if modifier.Type() == "public" { + public = true + } + } + } + + nodeutil.AssertTypeIs(node.ChildByFieldName("name"), "identifier") + + name := node.ChildByFieldName("name").Content(source) + declaration := &Definition{ + Name: HandleExportStatus(public, name), + OriginalName: name, + Parameters: []*Definition{}, + } + + if node.Type() == "method_declaration" { + declaration.Type = nodeToStr(astutil.ParseType(node.ChildByFieldName("type"), source)) + declaration.OriginalType = node.ChildByFieldName("type").Content(source) + } else { + // A constructor declaration returns the type being constructed + + // Rename the constructor with "New" + name of type + declaration.Rename(HandleExportStatus(public, "New") + name) + 
declaration.Constructor = true + + // There is no original type, and the constructor returns the name of + // the new type + declaration.Type = name + } + + // Parse the parameters + + for _, parameter := range nodeutil.NamedChildrenOf(node.ChildByFieldName("parameters")) { + + var paramName string + var paramType *sitter.Node + + // If this is a spread parameter, then it will be in the format: + // (type) (variable_declarator name: (name)) + if parameter.Type() == "spread_parameter" { + paramName = parameter.NamedChild(1).ChildByFieldName("name").Content(source) + paramType = parameter.NamedChild(0) + } else { + paramName = parameter.ChildByFieldName("name").Content(source) + paramType = parameter.ChildByFieldName("type") + } + + declaration.Parameters = append(declaration.Parameters, &Definition{ + Name: paramName, + OriginalName: paramName, + Type: nodeToStr(astutil.ParseType(paramType, source)), + OriginalType: paramType.Content(source), + }) + } + + if node.ChildByFieldName("body") != nil { + methodScope := parseScope(node.ChildByFieldName("body"), source) + if !methodScope.IsEmpty() { + declaration.Children = append(declaration.Children, methodScope.Children...) 
+ } + } + + scope.Methods = append(scope.Methods, declaration) + case "class_declaration", "interface_declaration", "enum_declaration": + other := parseClassScope(node, source) + // Any subclasses will be renamed to part of their parent class + other.Class.Rename(scope.Class.Name + other.Class.Name) + scope.Subclasses = append(scope.Subclasses, other) + } + } + + return scope +} + +func parseScope(root *sitter.Node, source []byte) *Definition { + def := &Definition{} + for _, node := range nodeutil.NamedChildrenOf(root) { + switch node.Type() { + case "local_variable_declaration": + /* + name := nodeToStr(ParseExpr(node.ChildByFieldName("declarator").ChildByFieldName("name"), source, Ctx{})) + def.Children = append(def.Children, &symbol.Definition{ + OriginalName: name, + OriginalType: node.ChildByFieldName("type").Content(source), + Type: nodeToStr(ParseExpr(node.ChildByFieldName("type"), source, Ctx{})), + Name: name, + }) + */ + case "for_statement", "enhanced_for_statement", "while_statement", "if_statement": + def.Children = append(def.Children, parseScope(node, source)) + } + } + return def +} diff --git a/symbol/parsing_helpers.go b/symbol/parsing_helpers.go new file mode 100644 index 0000000..70d4fc0 --- /dev/null +++ b/symbol/parsing_helpers.go @@ -0,0 +1,37 @@ +package symbol + +import ( + "bytes" + "go/printer" + "go/token" + "unicode" +) + +// Uppercase uppercases the first character of the given string +func Uppercase(name string) string { + return string(unicode.ToUpper(rune(name[0]))) + name[1:] +} + +// Lowercase lowercases the first character of the given string +func Lowercase(name string) string { + return string(unicode.ToLower(rune(name[0]))) + name[1:] +} + +// HandleExportStatus is a convenience method for renaming methods that may be +// either public or private, and need to be renamed +func HandleExportStatus(exported bool, name string) string { + if exported { + return Uppercase(name) + } + return Lowercase(name) +} + +// nodeToStr 
converts any AST node to its string representation +func nodeToStr(node any) string { + var s bytes.Buffer + err := printer.Fprint(&s, token.NewFileSet(), node) + if err != nil { + panic(err) + } + return s.String() +} diff --git a/symbol/symbols.go b/symbol/symbols.go new file mode 100644 index 0000000..960b241 --- /dev/null +++ b/symbol/symbols.go @@ -0,0 +1,83 @@ +package symbol + +import ( + sitter "github.com/smacker/go-tree-sitter" +) + +// Go reserved keywords that are not Java keywords, and create invalid code +var reservedKeywords = []string{"chan", "defer", "fallthrough", "func", "go", "map", "range", "select", "struct", "type"} + +// IsReserved tests if a given identifier conflicts with a Go reserved keyword +func IsReserved(name string) bool { + for _, keyword := range reservedKeywords { + if keyword == name { + return true + } + } + return false +} + +// TypeOfLiteral returns the corresponding type for a Java literal +func TypeOfLiteral(node *sitter.Node, source []byte) string { + var originalType string + + switch node.Type() { + case "decimal_integer_literal": + switch node.Content(source)[len(node.Content(source))-1] { + case 'L': + originalType = "long" + default: + originalType = "int" + } + case "hex_integer_literal": + panic("here") + case "decimal_floating_point_literal": + switch node.Content(source)[len(node.Content(source))-1] { + case 'D': + originalType = "double" + default: + originalType = "float" + } + case "string_literal": + originalType = "String" + case "character_literal": + originalType = "char" + } + + return originalType +} + +// ResolveDefinition resolves a given definition, given its scope in the file +// It returns `true` on a successful resolution, or `false` otherwise +// +// Resolving a definition means that the type of the file is matched up with the type defined +// in the local scope or otherwise +func ResolveDefinition(definition *Definition, fileScope *FileScope) bool { + // Look in the class scope first + //if 
localClassDef := fileScope.FindClass().ByType(definition.Type); localClassDef != nil { + if localClassDef := fileScope.BaseClass.FindClass(definition.Type); localClassDef != nil { + // Every type in the local scope is a reference type, so prefix it with a pointer + definition.Type = "*" + localClassDef.Name + return true + + } else if globalDef, in := fileScope.Imports[definition.Type]; in { // Look through the imports + // Find what package the type is in + if packageDef := GlobalScope.FindPackage(globalDef); packageDef != nil { + definition.Type = packageDef.FindClass(definition.Type).FindClass(definition.Type).Type + } + return true + } + + // Unresolved + return false +} + +// ResolveChildren recursively resolves a definition and all of its children +// It returns true if all definitions were resolved correctly, and false otherwise +func ResolveChildren(definition *Definition, fileScope *FileScope) bool { + result := ResolveDefinition(definition, fileScope) + for _, child := range definition.Children { + result = ResolveChildren(child, fileScope) && result + } + return result +} diff --git a/testfiles/NameCollisions.java b/testfiles/NameCollisions.java index d42b17d..9dbac24 100644 --- a/testfiles/NameCollisions.java +++ b/testfiles/NameCollisions.java @@ -1,3 +1,5 @@ +package com.example; + import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -9,13 +11,15 @@ public class NameCollisions { // `map` is a reserved keyword Map map; + Map test = new HashMap<>(); + // Since `type` is a reserved keyword in Go, this should fail public int getFruit(String type) { return this.fruitTypes.get(type); } // This is also a collision, with the keyword `range` - public int[] range() { + private int[] range() { int[] values = new int[this.map.size()]; int ind = 0; for (int val : this.map.values()) { @@ -33,6 +37,8 @@ public NameCollisions() { public static void main(String[] args) { NameCollisions test = new NameCollisions(); + System.out.println(test.map); 
+ // Even more collisions Map map = new HashMap<>(); map.put("Apple", 1); diff --git a/testfiles/ResolveTesting.java b/testfiles/ResolveTesting.java new file mode 100644 index 0000000..a9d21a9 --- /dev/null +++ b/testfiles/ResolveTesting.java @@ -0,0 +1,33 @@ +public class ResolveTesting { + + // This should be changed + Node temp; + + private class Node { + int value; + + public Node(int value) { + this.value = value; + } + } + + // The parameter of this should change + public void incrementNode(Node target) { + target.value++; + } + + public int square(int x1, int x2) { + return x1 * x2; + } + + public Node add(Node n1, Node n2) { + Node temp = null; + temp = new Node(n1.value + n2.value); + return temp; + } + + // The parameter and return type should change + public Node duplicateNode(Node target) { + return new Node(target.value); + } +} diff --git a/testfiles/ScrambledForLoops.java b/testfiles/ScrambledForLoops.java index 8e639f2..cbcc433 100644 --- a/testfiles/ScrambledForLoops.java +++ b/testfiles/ScrambledForLoops.java @@ -30,11 +30,11 @@ public static void main(String[] args) { i++; } - System.out.println("Multiple declaration") + System.out.println("Multiple declaration"); int e; int f; - for (e = 1, f = 1; e < 3; i++) { - System.out.println(e) + for (e = 1, f = 1; e < 3; e++) { + System.out.println(e); } } } diff --git a/tree_sitter.go b/tree_sitter.go index d1cdc19..1b01d22 100644 --- a/tree_sitter.go +++ b/tree_sitter.go @@ -3,36 +3,18 @@ package main import ( "fmt" "go/ast" - "unicode" + "github.com/NickyBoy89/java2go/astutil" + "github.com/NickyBoy89/java2go/nodeutil" + "github.com/NickyBoy89/java2go/symbol" log "github.com/sirupsen/logrus" sitter "github.com/smacker/go-tree-sitter" ) -// Children gets all named children of a given node -func Children(node *sitter.Node) []*sitter.Node { - count := int(node.NamedChildCount()) - children := make([]*sitter.Node, count) - for i := 0; i < count; i++ { - children[i] = node.NamedChild(i) - } - return 
children -} - -// UnnamedChildren gets all the named + unnamed children of a given node -func UnnamedChildren(node *sitter.Node) []*sitter.Node { - count := int(node.ChildCount()) - children := make([]*sitter.Node, count) - for i := 0; i < count; i++ { - children[i] = node.Child(i) - } - return children -} - // Inspect is a function for debugging that prints out every named child of a // given node and the source code for that child func Inspect(node *sitter.Node, source []byte) { - for _, c := range Children(node) { + for _, c := range nodeutil.NamedChildrenOf(node) { fmt.Println(c, c.Content(source)) } } @@ -40,40 +22,51 @@ func Inspect(node *sitter.Node, source []byte) { // CapitalizeIdent capitalizes the first letter of a `*ast.Ident` to mark the // result as a public method or field func CapitalizeIdent(in *ast.Ident) *ast.Ident { - return &ast.Ident{Name: ToPublic(in.Name)} + return &ast.Ident{Name: symbol.Uppercase(in.Name)} } // LowercaseIdent lowercases the first letter of a `*ast.Ident` to mark the // result as a private method or field func LowercaseIdent(in *ast.Ident) *ast.Ident { - return &ast.Ident{Name: ToPrivate(in.Name)} -} - -// ToPublic uppercases the first character of the given string -func ToPublic(name string) string { - return string(unicode.ToUpper(rune(name[0]))) + name[1:] -} - -// ToPrivate lowercases the first character of the given string -func ToPrivate(name string) string { - return string(unicode.ToLower(rune(name[0]))) + name[1:] + return &ast.Ident{Name: symbol.Lowercase(in.Name)} } -// A Ctx is passed into the `ParseNode` function and contains any data that is -// needed down-the-line for parsing, such as the class's name +// A Ctx is all the context that is needed to parse a single source file type Ctx struct { // Used to generate the names of all the methods, as well as the names // of the constructors className string + + // Symbols for the current file being parsed + currentFile *symbol.FileScope + currentClass 
*symbol.ClassScope + + // The symbols of the current + localScope *symbol.Definition + // Used when generating arrays, because in Java, these are defined as // arrType[] varName = {item, item, item}, and no class name data is defined // Can either be of type `*ast.Ident` or `*ast.StarExpr` lastType ast.Expr } -// Parses a given tree-sitter node and returns the ast representation for it -// if called on the root of a tree-sitter node, it will return the entire -// generated golang ast as a `ast.Node` type +// Clone performs a shallow copy on a `Ctx`, returning a new Ctx with its pointers +// pointing at the same things as the previous Ctx +func (c Ctx) Clone() Ctx { + return Ctx{ + className: c.className, + currentFile: c.currentFile, + currentClass: c.currentClass, + localScope: c.localScope, + lastType: c.lastType, + } +} + +// ParseNode parses a given tree-sitter node and returns the ast representation +// +// This function is called when the node being parsed might not be a direct +// expression or statement, as those are parsed with `ParseExpr` and `ParseStmt` +// respectively func ParseNode(node *sitter.Node, source []byte, ctx Ctx) interface{} { switch node.Type() { case "ERROR": @@ -88,7 +81,7 @@ func ParseNode(node *sitter.Node, source []byte, ctx Ctx) interface{} { Name: &ast.Ident{Name: "main"}, } - for _, c := range Children(node) { + for _, c := range nodeutil.NamedChildrenOf(node) { switch c.Type() { case "package_declaration": program.Name = &ast.Ident{Name: c.NamedChild(0).NamedChild(int(c.NamedChild(0).NamedChildCount()) - 1).Content(source)} @@ -100,42 +93,29 @@ func ParseNode(node *sitter.Node, source []byte, ctx Ctx) interface{} { } return program case "field_declaration": - var fieldType ast.Expr - var fieldName *ast.Ident - var public bool - var fieldOffset int - - for ind, c := range Children(node) { - switch c.Type() { - case "modifiers": // Ignore the modifiers for now - for _, modifier := range UnnamedChildren(c) { - if modifier.Type() == 
"public" { - public = true - } + if node.NamedChild(0).Type() == "modifiers" { + for _, modifier := range nodeutil.UnnamedChildrenOf(node.NamedChild(0)) { + if modifier.Type() == "public" { + public = true } - fieldOffset = ind + 1 } } - if fieldType == nil { - fieldType = ParseExpr(node.NamedChild(fieldOffset), source, ctx) - fieldName = ParseExpr(node.NamedChild(fieldOffset+1).NamedChild(0), source, ctx).(*ast.Ident) - } - - if public { - fieldName = CapitalizeIdent(fieldName) - } else { - fieldName = LowercaseIdent(fieldName) - } + fieldType := ParseExpr(node.ChildByFieldName("type"), source, ctx) + fieldName := ParseExpr(node.ChildByFieldName("declarator").ChildByFieldName("name"), source, ctx).(*ast.Ident) + fieldName.Name = symbol.HandleExportStatus(public, fieldName.Name) - // If the field had a value associated with it, (ex: variable = NewValue()) - if node.NamedChild(fieldOffset+1).NamedChildCount() > 1 { + // If the field is assigned to a value (ex: int field = 1) + fieldAssignmentNode := node.ChildByFieldName("declarator").ChildByFieldName("value") + if fieldAssignmentNode != nil { return &ast.ValueSpec{ - Names: []*ast.Ident{fieldName}, - Type: fieldType, - Values: []ast.Expr{ParseExpr(node.NamedChild(fieldOffset+1).NamedChild(1), source, ctx)}, + Names: []*ast.Ident{fieldName}, + Type: fieldType, + Values: []ast.Expr{ + ParseExpr(fieldAssignmentNode, source, ctx), + }, } } @@ -146,21 +126,14 @@ func ParseNode(node *sitter.Node, source []byte, ctx Ctx) interface{} { case "import_declaration": return &ast.ImportSpec{Name: ParseExpr(node.NamedChild(0), source, ctx).(*ast.Ident)} case "method_declaration": - var public bool - comments := []*ast.Comment{} if node.NamedChild(0).Type() == "modifiers" { - cursor := sitter.NewTreeCursor(node.NamedChild(0)) - defer cursor.Close() - cursor.GoToFirstChild() - for cursor.GoToNextSibling() { - switch cursor.CurrentNode().Type() { - case "public": - public = true + for _, modifier := range 
nodeutil.UnnamedChildrenOf(node.NamedChild(0)) { + switch modifier.Type() { case "marker_annotation", "annotation": - comments = append(comments, &ast.Comment{Text: "//" + cursor.CurrentNode().Content(source)}) - if _, in := excludedAnnotations[cursor.CurrentNode().Content(source)]; in { + comments = append(comments, &ast.Comment{Text: "//" + modifier.Content(source)}) + if _, in := excludedAnnotations[modifier.Content(source)]; in { // If this entire method is ignored, we return an empty field, which // is handled by the logic that parses a class file return &ast.Field{} @@ -169,20 +142,52 @@ func ParseNode(node *sitter.Node, source []byte, ctx Ctx) interface{} { } } - name := LowercaseIdent(ParseExpr(node.ChildByFieldName("name"), source, ctx).(*ast.Ident)) + parameters := &ast.FieldList{} + + for _, param := range nodeutil.NamedChildrenOf(node.ChildByFieldName("parameters")) { + parameters.List = append(parameters.List, ParseNode(param, source, ctx).(*ast.Field)) + } + + methodName := node.ChildByFieldName("name").Content(source) + methodParameters := node.ChildByFieldName("parameters") + + comparison := func(d *symbol.Definition) bool { + // The names must match + if methodName != d.OriginalName { + return false + } + + // Size of parameters must match + if int(methodParameters.NamedChildCount()) != len(d.Parameters) { + return false + } + + // Go through the types and check to see if they differ + for index, param := range nodeutil.NamedChildrenOf(methodParameters) { + var paramType string + if param.Type() == "spread_parameter" { + paramType = param.NamedChild(0).Content(source) + } else { + paramType = param.ChildByFieldName("type").Content(source) + } + if paramType != d.Parameters[index].OriginalType { + return false + } + } - if public { - name = CapitalizeIdent(name) + return true } + def := ctx.currentClass.FindMethod().By(comparison)[0] + return &ast.Field{ Doc: &ast.CommentGroup{List: comments}, - Names: []*ast.Ident{name}, + Names: 
[]*ast.Ident{&ast.Ident{Name: def.Name}}, Type: &ast.FuncType{ - Params: ParseNode(node.ChildByFieldName("parameters"), source, ctx).(*ast.FieldList), + Params: parameters, Results: &ast.FieldList{List: []*ast.Field{ &ast.Field{ - Type: ParseExpr(node.ChildByFieldName("type"), source, ctx), + Type: &ast.Ident{Name: def.Type}, }, }, }, @@ -211,61 +216,51 @@ func ParseNode(node *sitter.Node, source []byte, ctx Ctx) interface{} { return &ast.CaseClause{} case "argument_list": args := []ast.Expr{} - for _, c := range Children(node) { + for _, c := range nodeutil.NamedChildrenOf(node) { args = append(args, ParseExpr(c, source, ctx)) } return args case "formal_parameters": params := &ast.FieldList{} - for _, param := range Children(node) { + for _, param := range nodeutil.NamedChildrenOf(node) { params.List = append(params.List, ParseNode(param, source, ctx).(*ast.Field)) } return params case "formal_parameter": - // If the parameter has an annotation, ignore that - offset := 0 - if node.NamedChild(0).Type() == "modifiers" { - offset = 1 + if ctx.localScope != nil { + paramDef := ctx.localScope.ParameterByName(node.ChildByFieldName("name").Content(source)) + if paramDef == nil { + paramDef = &symbol.Definition{ + Name: node.ChildByFieldName("name").Content(source), + Type: node.ChildByFieldName("type").Content(source), + } + } + return &ast.Field{ + Names: []*ast.Ident{&ast.Ident{Name: paramDef.Name}}, + Type: &ast.Ident{Name: paramDef.Type}, + } } return &ast.Field{ - Names: []*ast.Ident{ParseExpr(node.NamedChild(offset+1), source, ctx).(*ast.Ident)}, - Type: ParseExpr(node.NamedChild(offset), source, ctx), + Names: []*ast.Ident{ParseExpr(node.ChildByFieldName("name"), source, ctx).(*ast.Ident)}, + Type: astutil.ParseType(node.ChildByFieldName("type"), source), } case "spread_parameter": // The spread paramater takes a list and separates it into multiple elements // Ex: addElements([]int elements...) 
- switch ParseExpr(node.NamedChild(0), source, ctx).(type) { - case *ast.StarExpr: - // If the parameter is a reference type (ex: ...[]*Test), then the type is - // a `StarExpr`, which is passed into the ellipsis - return &ast.Field{ - Names: []*ast.Ident{ParseExpr(node.NamedChild(1).NamedChild(0), source, ctx).(*ast.Ident)}, - Type: &ast.Ellipsis{ - Elt: ParseExpr(node.NamedChild(0), source, ctx), - }, - } - case *ast.ArrayType: - // Represents something such as `byte[]... name` - return &ast.Field{ - Names: []*ast.Ident{ParseExpr(node.NamedChild(1).NamedChild(0), source, ctx).(*ast.Ident)}, - Type: &ast.Ellipsis{ - Elt: ParseExpr(node.NamedChild(0), source, ctx), - }, - } - } + spreadType := node.NamedChild(0) + spreadDeclarator := node.NamedChild(1) return &ast.Field{ - Names: []*ast.Ident{ParseExpr(node.NamedChild(0), source, ctx).(*ast.Ident)}, + Names: []*ast.Ident{ParseExpr(spreadDeclarator.ChildByFieldName("name"), source, ctx).(*ast.Ident)}, Type: &ast.Ellipsis{ - // This comes as a variable declarator, but we only need need the identifier for the type - Elt: ParseExpr(node.NamedChild(1).NamedChild(0), source, ctx), + Elt: astutil.ParseType(spreadType, source), }, } case "inferred_parameters": params := &ast.FieldList{} - for _, param := range Children(node) { + for _, param := range nodeutil.NamedChildrenOf(node) { params.List = append(params.List, &ast.Field{ Names: []*ast.Ident{ParseExpr(param, source, ctx).(*ast.Ident)}, // When we're not sure what parameters to infer, set them as interface diff --git a/typecheck.go b/typecheck.go deleted file mode 100644 index 2a5c961..0000000 --- a/typecheck.go +++ /dev/null @@ -1,137 +0,0 @@ -package main - -import ( - "bytes" - "go/ast" - "go/printer" - "go/token" - - sitter "github.com/smacker/go-tree-sitter" -) - -// ClassScope contains the global and local scopes for a single file -// if a file contains multiply classes, all the definitions are folded into -// one ClassScope -type ClassScope struct { - Classes 
[]*Definition - Fields []*Definition - Methods []*Definition -} - -// A Definition contains information about a single entry -type Definition struct { - // The original java name - OriginalName string - // The display name, may be different from the original name - Name string - // Type of the object - Type string - // If the definition is a constructor - Constructor bool - // If the object is a function, it has parameters - Parameters []*Definition - // Children of the declaration, if the declaration is a scope - Children []*Definition -} - -// Rename renames a definition for a type so that it can be referenced later with -// the correct name -func (d *Definition) Rename(name string) { - d.Name = name -} - -func (d Definition) isEmpty() bool { - return d.OriginalName == "" && len(d.Children) == 0 -} - -func nodeToStr(node any) string { - var s bytes.Buffer - err := printer.Fprint(&s, token.NewFileSet(), node) - if err != nil { - panic(err) - } - return s.String() -} - -// ExtractDefinitions generates a symbol table containing all the definitions -// for a single input file -func ExtractDefinitions(root *sitter.Node, source []byte) *ClassScope { - return parseClassScope(root.NamedChild(0), source) -} - -func parseClassScope(root *sitter.Node, source []byte) *ClassScope { - className := nodeToStr(ParseExpr(root.ChildByFieldName("name"), source, Ctx{})) - scope := &ClassScope{ - Classes: []*Definition{ - &Definition{ - OriginalName: className, - Name: className, - }, - }, - } - - var node *sitter.Node - for i := 0; i < int(root.ChildByFieldName("body").NamedChildCount()); i++ { - node = root.ChildByFieldName("body").NamedChild(i) - switch node.Type() { - case "field_declaration": - name := nodeToStr(ParseExpr(node.ChildByFieldName("declarator").ChildByFieldName("name"), source, Ctx{})) - scope.Fields = append(scope.Fields, &Definition{ - OriginalName: name, - Type: nodeToStr(ParseExpr(node.ChildByFieldName("type"), source, Ctx{})), - Name: 
nodeToStr(ParseExpr(node.ChildByFieldName("declarator").ChildByFieldName("name"), source, Ctx{})), - }) - case "method_declaration", "constructor_declaration": - name := nodeToStr(ParseExpr(node.ChildByFieldName("name"), source, Ctx{})) - declaration := &Definition{ - OriginalName: name, - Name: name, - } - - if node.Type() == "method_declaration" { - declaration.Type = nodeToStr(ParseExpr(node.ChildByFieldName("type"), source, Ctx{})) - } else { - // A constructor returns itself, so it does not have a type - declaration.Constructor = true - } - - for _, param := range ParseNode(node.ChildByFieldName("parameters"), source, Ctx{}).(*ast.FieldList).List { - name := nodeToStr(param.Names[0]) - declaration.Parameters = append(declaration.Parameters, &Definition{ - OriginalName: name, - Type: nodeToStr(param.Type), - Name: name, - }) - } - - methodScope := parseScope(node.ChildByFieldName("body"), source) - if !methodScope.isEmpty() { - declaration.Children = append(declaration.Children, methodScope) - } - - scope.Methods = append(scope.Methods, declaration) - } - } - - return scope -} - -func parseScope(root *sitter.Node, source []byte) *Definition { - def := &Definition{} - var node *sitter.Node - for i := 0; i < int(root.NamedChildCount()); i++ { - node = root.NamedChild(i) - switch node.Type() { - case "local_variable_declaration": - name := nodeToStr(ParseExpr(node.ChildByFieldName("declarator").ChildByFieldName("name"), source, Ctx{})) - def.Children = append(def.Children, &Definition{ - OriginalName: name, - Type: nodeToStr(ParseExpr(node.ChildByFieldName("type"), source, Ctx{})), - Name: name, - }) - case "for_statement", "enhanced_for_statement", "while_statement", "if_statement": - def.Children = append(def.Children, parseScope(node, source)) - } - } - return def -}