Skip to content

Commit

Permalink
Tag names with definition type, add tests for recursive structs (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
keynmol authored Jan 29, 2022
1 parent 8b308c4 commit 7e1af01
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 162 deletions.
28 changes: 20 additions & 8 deletions bindgen/src/main/scala/Def.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,26 @@ import Def.*

case class BindingDefinition(item: Def, isFromMainFile: Boolean)

enum DefTag:
case Union, Alias, Struct, Function, Enum
object DefTag:
import DefTag.*
def all = Set(Union, Alias, Struct, Function, Enum)

case class DefName(n: String, tg: DefTag)

case class Binding(
var named: mutable.Map[String, BindingDefinition] = mutable.Map.empty
var named: mutable.Map[DefName, BindingDefinition] = mutable.Map.empty
):
def add(item: Def, isFromMainFile: Boolean) =
item.defName.foreach { n =>
named.addOne(n -> BindingDefinition(item, isFromMainFile))
}
this

def remove(name: DefName): Binding =
named.remove(name)
this
def aliases: Set[Def.Alias] = named.collect {
case (k, BindingDefinition(item: Def.Alias, _)) => item
}.toSet
Expand Down Expand Up @@ -57,14 +70,13 @@ enum Def:
)
case Alias(name: String, underlying: CType)

def defName: Option[String] =
def defName: Option[DefName] =
this match
case Alias(name, _) => Some(name)
case Union(_, name) => Some(name)
case f: Function => Some(f.name)
case s: Struct => Some(s.name)
case e: Enum => e.name

case Alias(name, _) => Some(DefName(name, DefTag.Alias))
case Union(_, name) => Some(DefName(name, DefTag.Union))
case f: Function => Some(DefName(f.name, DefTag.Function))
case s: Struct => Some(DefName(s.name, DefTag.Struct))
case e: Enum => e.name.map(DefName(_, DefTag.Enum))
end Def

object Def:
Expand Down
9 changes: 7 additions & 2 deletions bindgen/src/main/scala/analysis/analyse.scala
Original file line number Diff line number Diff line change
Expand Up @@ -172,19 +172,24 @@ def analyse(file: String)(using Zone)(using config: Config): Binding =

trace(s"Defined or used in main file: ${closure}")

binding.named.filterInPlace((k, _) => closure.contains(k))
binding.named.filterInPlace((k, _) => closure.contains(k.n))

trace("Binding information:")
binding.named.toList.sortBy(_._1).foreach { case (k, v) =>
binding.named.toList.sortBy(_._1.n).foreach { case (k, v) =>
trace(s"'$k': $v")
}

binding
end analyse

def addBuiltInAliases(binding: Binding): Binding =
val replaceTypes = DefTag.all - DefTag.Function
BuiltinType.all.foreach { tpe =>
val al = Def.Alias(tpe.short, CType.Reference(Name.BuiltIn(tpe)))
replaceTypes.foreach { tg =>
binding.remove(DefName(tpe.short, tg))
}
binding.add(al, isFromMainFile = false)
}
binding
end addBuiltInAliases
19 changes: 12 additions & 7 deletions bindgen/src/main/scala/analysis/closure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,30 @@ def definitionClosure(b: Def)(using Config): Set[String] =
end match
end definitionClosure

def computeClosure(named: Map[String, BindingDefinition])(using
def computeClosure(named: Map[DefName, BindingDefinition])(using
Config
): Set[String] =
import scala.collection.mutable

def expand(visited: Set[String], result: Set[String]): Set[String] =
val notVisited = result -- visited
trace(s"Closure computer: visited = $visited, result = $result")
// trace(s"Closure computer: visited = $visited, result = $result")
if notVisited.isEmpty then result
else
val grown = notVisited.flatMap(k =>
named.get(k).map(n => definitionClosure(n.item)).getOrElse(Set.empty)
)
trace(s"Closure computer: grown = $grown")
val grown = notVisited.flatMap { k =>
DefTag.all.flatMap { tg =>
named
.get(DefName(k, tg))
.map(n => definitionClosure(n.item))
.getOrElse(Set.empty)
}
}
// trace(s"Closure computer: grown = $grown")
if (grown -- result).nonEmpty then
expand(visited ++ notVisited, result ++ grown)
else result
end if
end expand

expand(Set.empty, named.filter(_._2.isFromMainFile).keySet)
expand(Set.empty, named.filter(_._2.isFromMainFile).keySet.map(_.n))
end computeClosure
9 changes: 9 additions & 0 deletions bindgen/src/main/scala/render/binding.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ def binding(
how: (A, Appender) => Unit
) =
defs.zipWithIndex.foreach { case (en, idx) =>
en match
case df: Def =>
df.defName.foreach { name =>
trace(s"Rendering $name")
}
case sf: GeneratedFunction.ScalaFunction =>
trace(s"Rendering Scala function '${sf.name}'")
case sf: GeneratedFunction.CFunction =>
trace(s"Rendering C function '${sf.name}'")
try how(
en,
to(out)
Expand Down
13 changes: 7 additions & 6 deletions bindgen/src/main/scala/render/hack_recursive_structs.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package bindgen.rendering

import bindgen.*
import scala.annotation.tailrec

def isCyclical(typ: CType, name: String)(using AliasResolver, Config): Boolean =
def go(t: CType, visited: Set[String], level: Int): Boolean =
import CType.*
trace((" " * level) + s"visiting $t with $visited")
// trace((" " * level) + s"visiting $t with $visited")
val result =
t match
case Reference(Name.Model(name)) =>
Expand Down Expand Up @@ -37,7 +38,7 @@ def isCyclical(typ: CType, name: String)(using AliasResolver, Config): Boolean =
end match
end result

trace((" " * level) + s"result of $t is '$result', visited: $visited")
// trace((" " * level) + s"result of $t is '$result', visited: $visited")
result
end go
go(typ, Set(name), 0)
Expand Down Expand Up @@ -81,10 +82,10 @@ def hack_recursive_structs(
aliasResolver(name) match
case Pointer(_: Function) => result(Reference(Name.Model(name)))

case other =>
throw Error(
s"Expected '$name' to point to a function pointer, got $other instead"
)
case other => None
// throw Error(
// s"Expected '$name' to point to a function pointer, got $other instead"
// )
end match

end if
Expand Down
23 changes: 23 additions & 0 deletions bindgen/src/test/resources/scala-native/recursive_structs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
typedef struct Recursive_Struct2 Recursive_Struct2;

typedef void (*hello_func)(struct Recursive_Struct2 *);

typedef struct Recursive_Struct2 {
struct Recrusive_Struct1 *hello2;
char *str;
} Recursive_Struct2;

typedef struct Recursive_Struct3 {
hello_func handler;
int two;
} Recursive_Struct3;

typedef struct Recrusive_Struct1 {
Recursive_Struct3 *hello;
double d;
} Recrusive_Struct1;

typedef struct Recrusive_Simple {
struct Recrusive_Simple *hello;
double d;
} Recrusive_Simple;
42 changes: 42 additions & 0 deletions bindgen/src/test/scala/TestRecursiveStructs.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package bindgen

import org.junit.Assert.*
import org.junit.Test

import scala.scalanative.unsafe.*
import scala.scalanative.unsigned.*

class TestRecursiveStructs:
import lib_test_recursive_structs.types.*
@Test def test_simple_recursive(): Unit =
zone {
val nested3 = Recrusive_Simple(null, 333.33)
val nested2 = Recrusive_Simple(nested3, 222.22)
val nested1 = Recrusive_Simple(nested2, 111.11)
val nested0 = Recrusive_Simple(nested1, 0.0)

assertEquals(111.11, (!(!nested0).hello).d, 0.01d)
assertEquals(222.22, (!(!nested1).hello).d, 0.01d)
assertEquals(333.33, (!(!nested2).hello).d, 0.01d)
assertEquals(null, (!nested3).hello)

(!(!nested0).hello).hello = nested3
assertEquals(
nested3.unary_!.d,
nested0.unary_!.hello.unary_!.hello.unary_!.d,
0.01d
)
}

@Test def test_mutually_recursive(): Unit =
zone {
val struct1 = Recrusive_Struct1(null, 1.0)
val struct2 = Recursive_Struct2(struct1, c"hello")
val struct3 = Recursive_Struct3(hello_func(null), 25)
(!struct1).hello = struct3

assertEquals(1.0, struct1.unary_!.d, 0.01d)
assertEquals('h', struct2.unary_!.str(0))

}
end TestRecursiveStructs
11 changes: 9 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def sampleBindings(location: File, builder: BindingBuilder, ci: ClangInfo) = {
val llvmInclude = ci.llvmInclude.map("-I" + _)

define(location / "cJSON.h", "libcjson", Some("cjson"), List("cJSON.h"))
define(location / "test.h", "libtest", Some("test"), List("test.h"))

define(
location /
"Clang-Index.h",
Expand All @@ -391,7 +391,6 @@ def sampleBindings(location: File, builder: BindingBuilder, ci: ClangInfo) = {
List("clang-c/Index.h"),
llvmInclude
)

define(
location /
"tree-sitter.h",
Expand All @@ -411,6 +410,14 @@ def sampleBindings(location: File, builder: BindingBuilder, ci: ClangInfo) = {
List("raylib.h"),
llvmInclude ++ clangInclude
)
/* define( */
/* location / */
/* "sokol_gfx.h", */
/* "sokol_gfx", */
/* Some("sokol_gfx"), */
/* List("sokol_gfx.h"), */
/* llvmInclude ++ clangInclude */
/* ) */

if (Platform.target.os == Platform.OS.MacOS)
define(
Expand Down
Loading

0 comments on commit 7e1af01

Please sign in to comment.